1 | #pragma once |
2 | |
3 | #include <condition_variable> |
4 | #include <memory> |
5 | #include <type_traits> |
6 | #include <utility> |
7 | |
8 | #include <ATen/core/Dict.h> |
9 | #include <ATen/core/List.h> |
10 | #include <ATen/core/IListRef.h> |
11 | #include <ATen/core/functional.h> |
12 | #include <ATen/core/jit_type.h> |
13 | #include <ATen/core/qualified_name.h> |
14 | #include <ATen/core/rref_interface.h> |
15 | #include <ATen/core/symbol.h> |
16 | #include <c10/core/DeviceGuard.h> |
17 | #include <c10/core/Event.h> |
18 | #include <c10/core/Scalar.h> |
19 | #include <c10/core/Stream.h> |
20 | #include <c10/core/StreamGuard.h> |
21 | #include <c10/core/TensorImpl.h> |
22 | #include <c10/core/UndefinedTensorImpl.h> |
23 | #include <c10/core/impl/DeviceGuardImplInterface.h> |
24 | #include <c10/util/FunctionRef.h> |
25 | #include <c10/util/hash.h> |
26 | #include <c10/util/intrusive_ptr.h> |
27 | #include <c10/util/irange.h> |
28 | |
29 | namespace torch { |
30 | namespace jit { |
31 | struct Function; |
32 | struct CompilationUnit; |
33 | } // namespace jit |
34 | TORCH_API bool isCustomClass(const c10::IValue& v); |
35 | } // namespace torch |
36 | namespace c10 { |
37 | struct IValue; |
38 | struct ClassType; |
39 | struct TupleType; |
40 | struct EnumType; |
41 | struct InferredType; |
42 | |
43 | // For custom class __init__ registration, we need to pass in a function |
44 | // that looks like this: [](IValue x, args...) |
45 | |
// However, make_boxed_from_unboxed_functor.h automatically sets the input
// types of the function by introspecting the argument types of the functor
// (which would be IValue in this case), whereas we need the type it binds to
// be Foo.
49 | |
50 | // Instead, we pass in a lambda [](ivalue_holder<CurClass> x, args...) from |
51 | // which getTypePtr can recover the original class pointer. |
52 | |
53 | template <typename TaggedCapsuleType> |
54 | struct tagged_capsule { |
55 | IValue ivalue; |
56 | }; |
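
// For example, a registration lambda for a custom class __init__ might look
// like the following sketch (MyStack is a hypothetical
// torch::CustomClassHolder subclass; the real registration machinery in
// custom_class.h generates roughly this shape):
// ```
// [](tagged_capsule<MyStack> self, int64_t capacity) {
//   auto holder = c10::make_intrusive<MyStack>(capacity);
//   self.ivalue.toObject()->setSlot(
//       0, IValue::make_capsule(std::move(holder)));
// }
// ```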
57 | |
58 | template <class T, class NullType> |
59 | c10::intrusive_ptr<T, NullType> IValue::moveToIntrusivePtr() { |
60 | auto t = c10::intrusive_ptr<T, NullType>::reclaim( |
61 | payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton() |
62 | ? NullType::singleton() |
63 | : static_cast<T*>(payload.u.as_intrusive_ptr)); |
64 | clearToNone(); |
65 | return t; |
66 | } |
67 | template <typename T, class NullType> |
68 | c10::intrusive_ptr<T, NullType> IValue::toIntrusivePtr() const { |
69 | if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) { |
70 | return c10::intrusive_ptr<T, NullType>(); |
71 | } |
72 | c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr); |
73 | return c10::intrusive_ptr<T, NullType>::reclaim( |
74 | static_cast<T*>(payload.u.as_intrusive_ptr)); |
75 | } |
76 | |
77 | template <class T, class U> |
78 | intrusive_ptr<T> static_intrusive_pointer_cast(intrusive_ptr<U> r) { |
79 | return intrusive_ptr<T>::reclaim(static_cast<T*>(r.release())); |
80 | } |
81 | |
82 | template <class T, class U> |
83 | intrusive_ptr<T> dynamic_intrusive_pointer_cast(intrusive_ptr<U> r) { |
84 | return intrusive_ptr<T>::reclaim(dynamic_cast<T*>(r.release())); |
85 | } |
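
// Usage sketch (Base and Derived are hypothetical intrusive_ptr_target
// types). Note that both casts take the source pointer by value and assume
// ownership of its reference count:
// ```
// c10::intrusive_ptr<Base> b = c10::make_intrusive<Derived>();
// c10::intrusive_ptr<Derived> d =
//     c10::static_intrusive_pointer_cast<Derived>(std::move(b));
// ```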
86 | |
inline c10::intrusive_ptr<ivalue::Future> IValue::toFuture() && {
  AT_ASSERT(isFuture(), "Expected Future but got ", tagKind());
  return moveToIntrusivePtr<ivalue::Future>();
}
inline c10::intrusive_ptr<ivalue::Future> IValue::toFuture() const& {
  AT_ASSERT(isFuture(), "Expected Future but got ", tagKind());
  return toIntrusivePtr<ivalue::Future>();
}
inline c10::intrusive_ptr<ivalue::Await> IValue::toAwait() && {
  AT_ASSERT(isAwait(), "Expected Await but got ", tagKind());
  return moveToIntrusivePtr<ivalue::Await>();
}
inline c10::intrusive_ptr<ivalue::Await> IValue::toAwait() const& {
  AT_ASSERT(isAwait(), "Expected Await but got ", tagKind());
  return toIntrusivePtr<ivalue::Await>();
}
inline c10::intrusive_ptr<c10::RRefInterface> IValue::toRRef() && {
  AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind());
  return moveToIntrusivePtr<c10::RRefInterface>();
}
inline c10::intrusive_ptr<c10::RRefInterface> IValue::toRRef() const& {
  AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind());
  return toIntrusivePtr<c10::RRefInterface>();
}
inline c10::intrusive_ptr<at::Quantizer> IValue::toQuantizer() && {
  AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind());
  return moveToIntrusivePtr<at::Quantizer>();
}
inline c10::intrusive_ptr<at::Quantizer> IValue::toQuantizer() const& {
  AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind());
  return toIntrusivePtr<at::Quantizer>();
}
inline c10::intrusive_ptr<ivalue::ConstantString> IValue::toString() && {
  AT_ASSERT(isString(), "Expected String but got ", tagKind());
  return moveToIntrusivePtr<ivalue::ConstantString>();
}
inline c10::intrusive_ptr<ivalue::ConstantString> IValue::toString() const& {
  AT_ASSERT(isString(), "Expected String but got ", tagKind());
  return toIntrusivePtr<ivalue::ConstantString>();
}
inline c10::intrusive_ptr<ivalue::Object> IValue::toObject() && {
  AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
  return moveToIntrusivePtr<ivalue::Object>();
}
inline c10::intrusive_ptr<ivalue::Object> IValue::toObject() const& {
  AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
  return toIntrusivePtr<ivalue::Object>();
}
inline c10::intrusive_ptr<ivalue::PyObjectHolder> IValue::
    toPyObjectHolder() && {
  TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind());
  return moveToIntrusivePtr<ivalue::PyObjectHolder>();
}
inline c10::intrusive_ptr<ivalue::PyObjectHolder> IValue::toPyObjectHolder()
    const& {
  TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind());
  return toIntrusivePtr<ivalue::PyObjectHolder>();
}
inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() && {
  TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind());
  return moveToIntrusivePtr<ivalue::EnumHolder>();
}
inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() const& {
  TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind());
  return toIntrusivePtr<ivalue::EnumHolder>();
}
inline c10::complex<double> IValue::toComplexDouble() const {
  TORCH_INTERNAL_ASSERT(isComplexDouble(), "Expected ComplexDouble but got ", tagKind());
  auto ptr = toIntrusivePtr<ivalue::ComplexHolder>();
  return (*ptr).val;
}
158 | inline at::Tensor IValue::toTensor() && { |
159 | if (C10_UNLIKELY(!isTensor())) { |
160 | reportToTensorTypeError(); |
161 | } |
162 | auto result = std::move(payload.as_tensor); |
163 | // As far as I can tell, omitting the usual explicit destructor call |
164 | // is not UB in and of itself, and it's a slight perf win. The |
165 | // destructor is a no-op, because the moved-from Tensor is |
166 | // effectively an intrusive_ptr in the null state, so we don't need |
167 | // the behavior for correctness reasons either. Leaving this |
168 | // explanatory comment, including commented-out destructor call, to |
169 | // make this abundantly clear. |
170 | // |
171 | // payload.as_tensor.~Tensor(); |
172 | clearToNone(); |
173 | return result; |
174 | } |
175 | inline at::Tensor& IValue::toTensor() & { |
176 | if (C10_UNLIKELY(!isTensor())) { |
177 | reportToTensorTypeError(); |
178 | } |
179 | return payload.as_tensor; |
180 | } |
181 | inline const at::Tensor& IValue::toTensor() const& { |
182 | if (C10_UNLIKELY(!isTensor())) { |
183 | reportToTensorTypeError(); |
184 | } |
185 | return payload.as_tensor; |
186 | } |
inline c10::Storage IValue::toStorage() && {
  AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
  return c10::Storage(
      moveToIntrusivePtr<at::StorageImpl>());
}
inline c10::Storage IValue::toStorage() const& {
  AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
  return c10::Storage(toIntrusivePtr<at::StorageImpl>());
}
inline c10::Stream IValue::toStream() && {
  AT_ASSERT(isStream(), "Expected Stream but got ", tagKind());
  auto ptr = toIntrusivePtr<ivalue::StreamData3Holder>();
  return c10::Stream::unpack3((*ptr).val.stream_id,
                              (*ptr).val.device_index,
                              (*ptr).val.device_type);
}
inline c10::Stream IValue::toStream() const& {
  AT_ASSERT(isStream(), "Expected Stream but got ", tagKind());
  auto ptr = toIntrusivePtr<ivalue::StreamData3Holder>();
  return c10::Stream::unpack3((*ptr).val.stream_id,
                              (*ptr).val.device_index,
                              (*ptr).val.device_type);
}
inline c10::intrusive_ptr<caffe2::Blob> IValue::toBlob() && {
  AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind());
  return moveToIntrusivePtr<caffe2::Blob>();
}
inline c10::intrusive_ptr<caffe2::Blob> IValue::toBlob() const& {
  AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind());
  return toIntrusivePtr<caffe2::Blob>();
}
219 | inline c10::intrusive_ptr<torch::CustomClassHolder> IValue::toCapsule() && { |
220 | TORCH_INTERNAL_ASSERT(isCapsule()); |
221 | return moveToIntrusivePtr<torch::CustomClassHolder>(); |
222 | } |
223 | inline c10::intrusive_ptr<torch::CustomClassHolder> IValue::toCapsule() const& { |
224 | TORCH_INTERNAL_ASSERT(isCapsule()); |
225 | return toIntrusivePtr<torch::CustomClassHolder>(); |
226 | } |
inline at::Generator IValue::toGenerator() && {
  AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind());
  return at::Generator(moveToIntrusivePtr<at::GeneratorImpl>());
}
inline at::Generator IValue::toGenerator() const& {
  AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind());
  return at::Generator(toIntrusivePtr<at::GeneratorImpl>());
}
inline c10::SymInt IValue::toSymInt() && {
  AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind());
  if (isSymInt()) {
    return c10::SymInt(moveToIntrusivePtr<c10::SymNodeImpl>());
  } else {
    return c10::SymInt(payload.u.as_int);
  }
}
inline c10::SymInt IValue::toSymInt() const& {
  AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind());
  if (isSymInt()) {
    return c10::SymInt(toIntrusivePtr<c10::SymNodeImpl>());
  } else {
    return c10::SymInt(payload.u.as_int);
  }
}
inline c10::SymFloat IValue::toSymFloat() && {
  AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind());
  if (isSymFloat()) {
    return c10::SymFloat(moveToIntrusivePtr<c10::SymNodeImpl>());
  } else {
    return c10::SymFloat(payload.u.as_double);
  }
}
inline c10::SymFloat IValue::toSymFloat() const& {
  AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind());
  if (isSymFloat()) {
    return c10::SymFloat(toIntrusivePtr<c10::SymNodeImpl>());
  } else {
    return c10::SymFloat(payload.u.as_double);
  }
}
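
// For example, an IValue holding a plain int converts to a SymInt carrying
// that concrete value rather than a symbolic node (a sketch):
// ```
// IValue iv(static_cast<int64_t>(42));
// c10::SymInt s = iv.toSymInt(); // non-symbolic, holds 42
// ```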
267 | |
268 | namespace ivalue { |
269 | |
270 | void TORCH_API |
271 | checkCustomClassType(const ClassType* expected_type, const Type* actual_type); |
272 | |
273 | template <typename T> |
274 | using Shared = c10::intrusive_ptr<T>; |
275 | |
276 | // string |
277 | struct TORCH_API ConstantString final : c10::intrusive_ptr_target { |
278 | private: |
279 | const std::string str_; |
280 | |
281 | public: |
282 | ConstantString(std::string str) : str_(std::move(str)) {} |
283 | ConstantString(c10::string_view str) : str_(std::string(str)) {} |
284 | static c10::intrusive_ptr<ConstantString> create(std::string str_); |
285 | static c10::intrusive_ptr<ConstantString> create(c10::string_view str_); |
286 | static c10::intrusive_ptr<ConstantString> create(const char* str_); |
287 | |
288 | const std::string& string() const { |
289 | return str_; |
290 | } |
291 | c10::string_view string_view() const { |
292 | return str_; |
293 | } |
294 | |
295 | operator const std::string&() const { |
296 | return string(); |
297 | } |
298 | TORCH_API friend std::ostream& operator<<( |
299 | std::ostream& out, |
300 | const ConstantString& v); |
301 | }; |
302 | |
303 | struct Future; |
304 | |
305 | struct TORCH_API TupleElements { |
306 | private: |
307 | size_t inlineSize_; |
308 | // We represent TupleElements this way to save doing a heap |
309 | // allocation in the common (at least for unpickling) case where we |
310 | // have only 3 elements. We have our own union instead of |
311 | // c10::SmallVector<IValue> because c10::SmallVector<IValue> always |
312 | // stores the begin/end/capacity pointers, which would be a waste of |
313 | // space in our use case. |
314 | union { |
315 | std::vector<IValue> elementsVector_; |
316 | // Don't want to declare a std::array because the convenient |
317 | // iteration and size members are a footgun in this case -- the |
318 | // actual size of the array may be smaller than 3! |
319 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays) |
320 | IValue elementsInline_[3]; |
321 | }; |
322 | |
323 | void destroyInline() { |
324 | for (const auto ii : c10::irange(inlineSize_)) { |
325 | elementsInline_[ii].~IValue(); |
326 | } |
327 | } |
328 | public: |
329 | |
330 | using iterator = IValue*; |
331 | using const_iterator = const IValue*; |
332 | |
333 | TupleElements() : inlineSize_(0) { |
334 | new (&elementsVector_) std::vector<IValue>(); |
335 | } |
336 | |
337 | explicit TupleElements(std::vector<IValue> elements) |
338 | : inlineSize_(0), elementsVector_(std::move(elements)) {} |
339 | |
340 | explicit TupleElements(c10::ArrayRef<IValue> elements) |
341 | : inlineSize_(elements.size() <= 3 ? elements.size() : 0) { |
342 | switch (inlineSize_) { |
343 | case 3: |
344 | new (&elementsInline_[2]) IValue(elements[2]); |
345 | C10_FALLTHROUGH; |
346 | case 2: |
347 | new (&elementsInline_[1]) IValue(elements[1]); |
348 | C10_FALLTHROUGH; |
349 | case 1: |
350 | new (&elementsInline_[0]) IValue(elements[0]); |
351 | break; |
352 | case 0: |
353 | new (&elementsVector_) std::vector<IValue>(elements.begin(), elements.end()); |
354 | break; |
355 | } |
356 | } |
357 | |
358 | explicit TupleElements(IValue&& e1) |
359 | : inlineSize_(1) { |
360 | new (&elementsInline_[0]) IValue(std::move(e1)); |
361 | } |
362 | |
363 | explicit TupleElements(IValue&& e1, IValue&& e2) |
364 | : inlineSize_(2) { |
365 | new (&elementsInline_[0]) IValue(std::move(e1)); |
366 | new (&elementsInline_[1]) IValue(std::move(e2)); |
367 | } |
368 | |
369 | explicit TupleElements(IValue&& e1, IValue&& e2, IValue&& e3) |
370 | : inlineSize_(3) { |
371 | new (&elementsInline_[0]) IValue(std::move(e1)); |
372 | new (&elementsInline_[1]) IValue(std::move(e2)); |
373 | new (&elementsInline_[2]) IValue(std::move(e3)); |
374 | } |
375 | |
376 | ~TupleElements() { |
377 | if (inlineSize_) { |
378 | destroyInline(); |
379 | } else { |
380 | elementsVector_.~vector(); |
381 | } |
382 | } |
383 | |
384 | // It would be nice to make this noncopyable to prevent people from |
385 | // writing code like `auto output = |
386 | // forward(...).toTupleRef().elements()` (which does refcount bumps on |
387 | // each element, unlike the more efficient but verbose |
388 | // ``` |
389 | // auto outputIntrusivePtr = forward(...).toTuple(); |
390 | // const auto& output = outputIntrusivePtr->elements(); |
391 | // ``` |
392 | // ), but there is simply an overwhelming amount of code that does |
393 | // it the inefficient way. |
394 | // See also operator std::vector below. |
395 | TupleElements(const TupleElements& rhs) |
396 | : inlineSize_(rhs.inlineSize_) { |
397 | if (rhs.inlineSize_) { |
398 | for (const auto ii : c10::irange(inlineSize_)) { |
399 | new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]); |
400 | } |
401 | } else { |
402 | new (&elementsVector_) std::vector<IValue>(rhs.elementsVector_); |
403 | } |
404 | } |
405 | |
406 | TupleElements& operator=(const TupleElements& rhs) { |
407 | if (inlineSize_) { |
408 | if (rhs.inlineSize_) { |
409 | for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) { |
410 | elementsInline_[ii] = rhs.elementsInline_[ii]; |
411 | } |
412 | if (rhs.inlineSize_ > inlineSize_) { |
413 | for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) { |
414 | new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]); |
415 | } |
416 | } else { |
417 | for (const auto ii : c10::irange(rhs.inlineSize_, inlineSize_)) { |
418 | elementsInline_[ii].~IValue(); |
419 | } |
420 | } |
421 | } else { |
422 | destroyInline(); |
423 | new (&elementsVector_) std::vector<IValue>(rhs.elementsVector_); |
424 | } |
425 | } else { |
426 | if (rhs.inlineSize_) { |
427 | elementsVector_.~vector(); |
428 | for (const auto ii : c10::irange(rhs.inlineSize_)) { |
429 | new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]); |
430 | } |
431 | } else { |
432 | elementsVector_ = rhs.elementsVector_; |
433 | } |
434 | } |
435 | inlineSize_ = rhs.inlineSize_; |
436 | return *this; |
437 | } |
438 | |
439 | TupleElements(TupleElements&& rhs) noexcept |
440 | : inlineSize_(rhs.inlineSize_) { |
441 | if (inlineSize_) { |
442 | for (const auto ii : c10::irange(inlineSize_)) { |
443 | new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii])); |
444 | } |
445 | } else { |
446 | new (&elementsVector_) std::vector<IValue>(std::move(rhs.elementsVector_)); |
447 | } |
448 | } |
449 | |
450 | TupleElements& operator=(TupleElements&& rhs) noexcept { |
451 | if (inlineSize_) { |
452 | if (rhs.inlineSize_) { |
453 | for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) { |
454 | elementsInline_[ii] = std::move(rhs.elementsInline_[ii]); |
455 | } |
456 | if (rhs.inlineSize_ > inlineSize_) { |
457 | for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) { |
458 | new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii])); |
459 | } |
460 | } else { |
461 | for (const auto ii : c10::irange(rhs.inlineSize_, inlineSize_)) { |
462 | elementsInline_[ii].~IValue(); |
463 | } |
464 | } |
465 | } else { |
466 | destroyInline(); |
467 | new (&elementsVector_) std::vector<IValue>(std::move(rhs.elementsVector_)); |
468 | } |
469 | } else { |
470 | if (rhs.inlineSize_) { |
471 | elementsVector_.~vector(); |
472 | for (const auto ii : c10::irange(rhs.inlineSize_)) { |
473 | new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii])); |
474 | } |
475 | } else { |
476 | elementsVector_ = std::move(rhs.elementsVector_); |
477 | } |
478 | } |
479 | inlineSize_ = rhs.inlineSize_; |
480 | return *this; |
481 | } |
482 | |
483 | C10_NODISCARD c10::ArrayRef<IValue> asArrayRef() const { |
484 | if (inlineSize_) { |
485 | return c10::ArrayRef<IValue>(elementsInline_, inlineSize_); |
486 | } else { |
487 | return elementsVector_; |
488 | } |
489 | } |
490 | |
491 | // Mimic implicit conversion from std::vector to ArrayRef. |
492 | operator c10::ArrayRef<IValue>() const { |
493 | return asArrayRef(); |
494 | } |
495 | |
496 | static size_t hash(const TupleElements& v) { |
497 | return c10::hash<c10::ArrayRef<IValue>>()(v.asArrayRef()); |
498 | } |
499 | |
500 | void setContents(std::vector<IValue>&& contents) { |
501 | if (inlineSize_) { |
502 | destroyInline(); |
503 | new (&elementsVector_) std::vector<IValue>(std::move(contents)); |
504 | inlineSize_ = 0; |
505 | } else { |
506 | elementsVector_ = std::move(contents); |
507 | } |
508 | } |
509 | |
510 | C10_NODISCARD bool empty() const { |
511 | return inlineSize_ ? false : elementsVector_.empty(); |
512 | } |
513 | |
514 | C10_NODISCARD size_t size() const { |
515 | return inlineSize_ ? inlineSize_ : elementsVector_.size(); |
516 | } |
517 | |
518 | C10_NODISCARD IValue& operator[](size_t idx) { |
519 | if (inlineSize_) { |
520 | return elementsInline_[idx]; |
521 | } else { |
522 | return elementsVector_[idx]; |
523 | } |
524 | } |
525 | |
526 | C10_NODISCARD const IValue& operator[](size_t idx) const { |
527 | if (inlineSize_) { |
528 | return elementsInline_[idx]; |
529 | } else { |
530 | return elementsVector_[idx]; |
531 | } |
532 | } |
533 | |
534 | C10_NODISCARD IValue& at(size_t idx) { |
535 | if (inlineSize_) { |
536 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3); |
      TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_);
538 | return elementsInline_[idx]; |
539 | } else { |
540 | return elementsVector_.at(idx); |
541 | } |
542 | } |
543 | |
544 | C10_NODISCARD const IValue& at(size_t idx) const { |
545 | if (inlineSize_) { |
546 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3); |
      TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_);
548 | return elementsInline_[idx]; |
549 | } else { |
      TORCH_CHECK(idx < elementsVector_.size(), "TupleElements: invalid index Index = ", idx, "; Length = ", elementsVector_.size());
551 | return elementsVector_.at(idx); |
552 | } |
553 | } |
554 | |
555 | C10_NODISCARD iterator begin() { |
556 | if (inlineSize_) { |
557 | return elementsInline_; |
558 | } else { |
559 | return elementsVector_.data(); |
560 | } |
561 | } |
562 | |
563 | C10_NODISCARD iterator end() { |
564 | if (inlineSize_) { |
565 | return elementsInline_ + inlineSize_; |
566 | } else { |
567 | return elementsVector_.data() + elementsVector_.size(); |
568 | } |
569 | } |
570 | |
571 | C10_NODISCARD const_iterator begin() const { |
572 | if (inlineSize_) { |
573 | return elementsInline_; |
574 | } else { |
575 | return elementsVector_.data(); |
576 | } |
577 | } |
578 | |
579 | C10_NODISCARD const_iterator end() const { |
580 | if (inlineSize_) { |
581 | return elementsInline_ + inlineSize_; |
582 | } else { |
583 | return elementsVector_.data() + elementsVector_.size(); |
584 | } |
585 | } |
586 | |
587 | C10_NODISCARD const_iterator cbegin() const { |
588 | return begin(); |
589 | } |
590 | |
591 | C10_NODISCARD const_iterator cend() const { |
592 | return end(); |
593 | } |
594 | |
595 | C10_NODISCARD std::vector<IValue> vec() const & { |
596 | return asArrayRef().vec(); |
597 | } |
598 | |
599 | C10_NODISCARD IValue& back() { |
600 | return *(end() - 1); |
601 | } |
602 | |
603 | C10_NODISCARD const IValue& back() const { |
604 | return *(end() - 1); |
605 | } |
606 | |
607 | C10_NODISCARD std::vector<IValue> vec() && { |
608 | std::vector<IValue> result; |
609 | result.reserve(size()); |
610 | for (auto&& iv : *this) { |
611 | result.push_back(std::move(iv)); |
612 | } |
613 | return result; |
614 | } |
615 | |
616 | // More compatibility shims for the overwhelming amount of code that |
617 | // likes to copy tuple elements into a vector; see comment above the |
618 | // copy constructor. |
619 | operator std::vector<IValue>() const & { |
620 | return vec(); |
621 | } |
622 | |
623 | operator std::vector<IValue>() && { |
624 | return vec(); |
625 | } |
626 | }; |
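
// A sketch of the small-buffer behavior described above: up to three
// elements are stored inline, while the vector-taking constructor always
// uses the heap-backed representation:
// ```
// TupleElements inlined(IValue(1), IValue(2), IValue(3)); // elementsInline_
// std::vector<IValue> vals{IValue(1), IValue(2), IValue(3), IValue(4)};
// TupleElements heaped(std::move(vals));                  // elementsVector_
// ```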
627 | |
628 | template <typename T> |
629 | struct TupleTypeFactory {}; |
630 | |
631 | template <> |
632 | struct TORCH_API TupleTypeFactory<TupleType> { |
633 | static TupleTypePtr create(std::vector<TypePtr> types) { |
634 | return TupleType::create(std::move(types)); |
635 | } |
636 | static TupleTypePtr fallback(const Type& type); |
637 | }; |
638 | |
639 | template <> |
640 | struct TORCH_API TupleTypeFactory<c10::DynamicType> { |
641 | static DynamicTypePtr create(std::vector<TypePtr> elemTypes); |
642 | static DynamicTypePtr fallback(const Type&); |
643 | }; |
644 | |
645 | struct TORCH_API Tuple : c10::intrusive_ptr_target { |
646 | private: |
647 | TupleElements elements_; |
648 | mutable c10::TypePtr type_; // lazily computed for unnamed tuples |
649 | |
650 | public: |
651 | // named tuples have additional type information, so we |
652 | // directly create them tagged |
653 | static c10::intrusive_ptr<Tuple> createNamed( |
654 | std::vector<IValue> elements_, |
655 | c10::TypePtr type_) { |
656 | return c10::make_intrusive<Tuple>(std::move(elements_), std::move(type_)); |
657 | } |
658 | |
659 | static c10::intrusive_ptr<Tuple> createNamed( |
660 | TupleElements elements_, |
661 | std::shared_ptr<TupleType> type_) { |
662 | return c10::make_intrusive<Tuple>(std::move(elements_), std::move(type_)); |
663 | } |
664 | |
665 | static c10::intrusive_ptr<Tuple> createNamed( |
666 | std::initializer_list<IValue> elements_, |
667 | std::shared_ptr<TupleType> type_) { |
668 | return createNamed(TupleElements(c10::ArrayRef<IValue>(elements_)), std::move(type_)); |
669 | } |
670 | |
671 | // MSVC apparently can't disambiguate the other two overloads of |
672 | // create when passed an initializer_list without this. |
673 | static c10::intrusive_ptr<Tuple> create(std::initializer_list<IValue> elements_) { |
674 | return create(c10::ArrayRef<IValue>(elements_)); |
675 | } |
676 | |
677 | static c10::intrusive_ptr<Tuple> create(std::vector<IValue> elements_) { |
678 | return c10::make_intrusive<Tuple>(std::move(elements_)); |
679 | } |
680 | |
681 | static c10::intrusive_ptr<Tuple> create(TupleElements elements_) { |
682 | return c10::make_intrusive<Tuple>(std::move(elements_)); |
683 | } |
684 | |
685 | static c10::intrusive_ptr<Tuple> create(c10::ArrayRef<IValue> elements_) { |
686 | return create(TupleElements(elements_)); |
687 | } |
688 | |
689 | static c10::intrusive_ptr<Tuple> create(IValue e1) { |
690 | return c10::make_intrusive<Tuple>(std::move(e1)); |
691 | } |
692 | |
693 | static c10::intrusive_ptr<Tuple> create(IValue e1, IValue e2) { |
694 | return c10::make_intrusive<Tuple>(std::move(e1), std::move(e2)); |
695 | } |
696 | |
697 | static c10::intrusive_ptr<Tuple> create(IValue e1, IValue e2, IValue e3) { |
698 | return c10::make_intrusive<Tuple>(std::move(e1), std::move(e2), std::move(e3)); |
699 | } |
700 | |
701 | private: |
702 | // Workaround inability to use `>` operator in template argument list. |
703 | template <typename... Args> |
704 | static constexpr bool hasMoreThanThreeArgs() { |
705 | return sizeof...(Args) > 3; |
706 | } |
707 | |
708 | public: |
709 | template <typename... Args> |
710 | static c10::intrusive_ptr<Tuple> create(Args&&... elements_) { |
711 | switch (sizeof...(Args)) { |
712 | case 1: |
713 | case 2: |
714 | case 3: |
715 | return create(IValue(std::forward<Args>(elements_))...); |
716 | default: |
717 | return create( |
718 | std::vector<IValue>{IValue(std::forward<Args>(elements_))...}); |
719 | } |
720 | } |
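
  // For example (a sketch; each argument is wrapped into an IValue, and a
  // three-element tuple stays in TupleElements' inline storage):
  // ```
  // auto tup = ivalue::Tuple::create(1, 2.0, true);
  // const IValue& first = tup->elements()[0]; // IValue holding 1
  // ```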
721 | |
722 | // Again, it would be nice to make this noncopyable, but there's a |
723 | // lot of extant code that copies Tuples. |
724 | // Tuple(const Tuple& rhs) = delete; |
725 | |
726 | const TupleElements& elements() const& { |
727 | return elements_; |
728 | } |
729 | |
730 | TupleElements elements() && { |
731 | return std::move(elements_); |
732 | } |
733 | |
734 | void setElements(std::vector<IValue>&& elements) { |
735 | elements_.setContents(std::move(elements)); |
736 | } |
737 | |
738 | void setElements(TupleElements&& elements) { |
739 | elements_ = std::move(elements); |
740 | } |
741 | |
742 | void unsafeSetElement(size_t idx, const IValue& element) { |
743 | elements_[idx] = element; |
744 | } |
745 | |
746 | void unsafeSetElement(size_t idx, IValue&& element) { |
747 | elements_[idx] = std::move(element); |
748 | } |
749 | |
750 | size_t size() const { |
751 | return elements_.size(); |
752 | } |
753 | |
754 | template <typename T = c10::TupleType> |
755 | std::shared_ptr<T> type() const { |
756 | if (!type_) { |
757 | type_ = TupleTypeFactory<T>::create(fmap(elements(), [&](const IValue& v) { |
758 | return v.type<typename T::ElementType>(); |
759 | })); |
760 | } |
761 | if (auto t = type_->cast<T>()) { |
762 | return t; |
763 | } |
764 | return TupleTypeFactory<T>::fallback(*type_); |
765 | } |
766 | |
767 | static size_t hash(const Tuple& t) { |
768 | return c10::get_hash(t.elements()); |
769 | } |
770 | |
771 | TORCH_API friend bool operator==( |
772 | const ivalue::Tuple& lhs, |
773 | const ivalue::Tuple& rhs); |
774 | |
775 | private: |
776 | // NOTE: If we try to avoid the overloads without |
777 | // `std::shared_ptr<TupleType> type` by defaulting it to nullptr, we |
778 | // end up having to call (part of) the shared_ptr destructor for |
779 | // `type` even though we should know statically it won't do |
780 | // anything. |
781 | explicit Tuple(std::vector<IValue> elements) |
782 | : elements_(std::move(elements)){} |
783 | |
784 | explicit Tuple(std::vector<IValue> elements, c10::TypePtr type) |
785 | : elements_(std::move(elements)), type_(std::move(type)) {} |
786 | |
787 | explicit Tuple(TupleElements&& elements) |
788 | : elements_(std::move(elements)) {} |
789 | |
790 | explicit Tuple(TupleElements&& elements, std::shared_ptr<TupleType> type) |
791 | : elements_(std::move(elements)), type_(std::move(type)) {} |
792 | |
793 | explicit Tuple(IValue&& e1) |
794 | : elements_(std::move(e1)) {} |
795 | |
796 | explicit Tuple(IValue&& e1, std::shared_ptr<TupleType> type) |
797 | : elements_(std::move(e1)), type_(std::move(type)) {} |
798 | |
799 | explicit Tuple(IValue&& e1, IValue&& e2) |
800 | : elements_(std::move(e1), std::move(e2)) {} |
801 | |
802 | explicit Tuple(IValue&& e1, IValue&& e2, std::shared_ptr<TupleType> type) |
803 | : elements_(std::move(e1), std::move(e2)), type_(std::move(type)) {} |
804 | |
805 | explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3) |
806 | : elements_(std::move(e1), std::move(e2), std::move(e3)) {} |
807 | |
808 | explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3, std::shared_ptr<TupleType> type) |
809 | : elements_(std::move(e1), std::move(e2), std::move(e3)), type_(std::move(type)) {} |
810 | |
811 | friend class c10::intrusive_ptr<Tuple>; |
812 | }; |
813 | |
814 | struct Object; |
815 | struct PyObjectHolder; |
816 | struct EnumHolder; |
817 | } // namespace ivalue |
818 | |
819 | // Future |
820 | struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { |
821 | private: |
822 | // Keep this private in order to force users to go through make_intrusive and |
823 | // thus prevent creating a Future that's not held by an intrusive_ptr. |
824 | explicit Future(TypePtr type, std::vector<c10::Device> devices={}) |
825 | : type_(std::move(type)), |
826 | impl_(getTypeOfDevices(devices)), |
827 | devices_(sortAndDeduplicateDevices(impl_, std::move(devices))) {} |
828 | |
829 | friend c10::intrusive_ptr<Future>; |
830 | |
831 | public: |
832 | Future(const Future&) = delete; |
833 | Future(Future&&) = delete; |
834 | Future& operator=(const Future&) = delete; |
835 | Future& operator=(Future&&) = delete; |
836 | |
837 | struct TORCH_API FutureError final : public std::exception { |
838 | explicit FutureError(std::string&& error_msg_) |
839 | : error_msg(std::move(error_msg_)) {} |
840 | |
841 | FutureError() = default; |
842 | |
843 | const char* what() const noexcept override { |
844 | return error_msg.c_str(); |
845 | } |
846 | |
847 | std::string error_msg; |
848 | }; |
849 | |
850 | /** |
851 | * Wait on the future until it completes. |
852 | */ |
853 | void wait() { |
854 | std::unique_lock<std::mutex> lock(mutex_); |
855 | finished_cv_.wait(lock, [&]() -> bool { return completed_; }); |
856 | synchronizeWithCurrentStreams(); |
857 | } |
858 | |
859 | /** |
860 | * Wait on the future until it completes and throw an |
861 | * exception if an error exists. |
862 | */ |
863 | void waitAndThrow() { |
864 | wait(); |
865 | |
866 | if (eptr_) { |
867 | std::rethrow_exception(eptr_); |
868 | } |
869 | } |
870 | |
871 | /** |
872 | * Explicitly mark the future as completed with the output value. Optionally, |
873 | * the storages for all tensors in IValue can be passed as well. The DataPtrs |
   * of these storages are used to synchronize CUDA streams. If storages aren't
   * given, we will attempt to extract them from the value if we need to (this
   * happens when a non-empty set of devices was given to the constructor). Thus
   * one only needs to provide storages when 1) they cannot be extracted through
   * IValue::getSubValues() or through pickling in the case of a Python object;
   * or when 2) customized storage extraction is more efficient.
880 | */ |
881 | using WeakStorage = c10::weak_intrusive_ptr<c10::StorageImpl>; |
882 | void markCompleted( |
883 | IValue value, |
884 | c10::optional<std::vector<WeakStorage>> storages = c10::nullopt) { |
885 | // Start by performing all steps that can throw, before setting any field. |
886 | // Do this before even acquiring the mutex, because extractStorages might |
887 | // acquire the GIL, which could lead to a lock inversion with our mutex. |
888 | // See https://github.com/pytorch/pytorch/issues/58239. |
889 | std::vector<WeakStorage> actualStorages; |
890 | std::vector<c10::Device> usedDevices; |
891 | try { |
892 | // FIXME We should always extract DataPtrs, in order to catch the case of |
893 | // users using CUDA values but forgetting to set devices, which currently |
894 | // leads to a silent synchronization/correctness issue. However, as this |
895 | // might worsen perf in CPU-only cases, we should only do so after careful |
896 | // benchmarks. |
897 | if (impl_.type() != c10::kCPU) { |
898 | actualStorages = |
899 | storages.has_value() ? std::move(*storages) : extractStorages(value); |
900 | usedDevices = getDevicesOfStorages(impl_, actualStorages); |
901 | ensureIsSubsetOfDevices(usedDevices, devices_); |
902 | } |
903 | } catch (const std::exception&) { |
904 | setError(std::current_exception()); |
905 | return; |
906 | } |
907 | |
908 | std::unique_lock<std::mutex> lock(mutex_); |
909 | TORCH_CHECK( |
910 | !completed(), |
911 | "Attempting to mark a completed Future as complete again. Note that " |
        "a Future can only be marked completed once.");
913 | |
914 | // Only set value_ and completed_ flag once all checks and preparation steps |
915 | // have returned successfully to allow for proper error propagation. |
916 | value_ = std::move(value); |
917 | completed_ = true; |
918 | |
919 | currentDevice_ = impl_.getDevice(); |
920 | storages_ = std::move(actualStorages); |
921 | for (const c10::Device& device : usedDevices) { |
922 | c10::Event event(impl_.type()); |
923 | event.record(impl_.getStream(device)); |
924 | events_.push_back(std::move(event)); |
925 | } |
926 | |
927 | std::vector<std::function<void(Future&)>> cbs; |
928 | cbs.swap(callbacks_); |
929 | lock.unlock(); |
930 | |
931 | finished_cv_.notify_all(); |
932 | for (auto& callback : cbs) { |
933 | invokeCallback(std::move(callback)); |
934 | } |
935 | } |
936 | |
937 | void markCompleted() { |
938 | markCompleted(IValue{}); |
939 | } |
940 | |
941 | void setError(std::exception_ptr eptr) { |
942 | std::unique_lock<std::mutex> lock(mutex_); |
943 | setErrorInternal(std::move(eptr), lock); |
944 | } |
945 | |
946 | void setErrorIfNeeded(std::exception_ptr eptr) { |
947 | std::unique_lock<std::mutex> lock(mutex_); |
948 | if (completed_) { |
      // This should be rare and shouldn't cause log spew. It's important to
      // log errors, and that's why we have this log here.
951 | std::string msg = c10::str( |
952 | "Skipping setting following error on the Future since " |
953 | "it is already marked completed (this is not necessarily " |
          "an error):\n",
955 | tryRetrieveErrorMessageInternal(std::move(eptr))); |
956 | if (eptr_) { |
957 | msg += c10::str( |
            ", \nOriginal exception:\n",
959 | tryRetrieveErrorMessageInternal(eptr_)); |
960 | } |
961 | LOG(INFO) << msg; |
962 | return; |
963 | } else { |
964 | setErrorInternal(std::move(eptr), lock); |
965 | } |
966 | } |
967 | |
968 | // Get the result of the current future. |
969 | IValue value() { |
970 | std::unique_lock<std::mutex> lock(mutex_); |
971 | AT_ASSERT(completed()); |
972 | if (eptr_) { |
973 | std::rethrow_exception(eptr_); |
974 | } |
975 | return value_; |
976 | } |
977 | |
978 | // This accessor should only be used if we know that the future is |
979 | // completed() with no error. |
980 | const IValue& constValue() const { |
981 | std::unique_lock<std::mutex> lock(mutex_); |
982 | AT_ASSERT(completed()); |
983 | TORCH_INTERNAL_ASSERT( |
984 | !eptr_, |
        "value() accessor should only be used when future is not completed with ",
        "an error, but future had the following error: ",
987 | tryRetrieveErrorMessageInternal(eptr_) |
988 | ); |
989 | return value_; |
990 | } |
991 | |
992 | // This accessor should only be used if we know that the future is |
993 | // completed() with no error. |
994 | const std::vector<WeakStorage>& storages() const { |
995 | std::unique_lock<std::mutex> lock(mutex_); |
996 | AT_ASSERT(completed()); |
997 | AT_ASSERT(!eptr_); |
998 | return storages_; |
999 | } |
1000 | |
1001 | /** |
1002 | * Add a callback to the future. |
1003 | * The callbacks will be executed once the future completes. |
1004 | * If the future has already completed, |
1005 | * this function will execute the callback immediately. |
1006 | */ |
1007 | template <typename T> |
1008 | void addCallback(T callback) { |
1009 | #if __cpp_lib_is_invocable >= 201703 |
1010 | static_assert( |
1011 | std::is_invocable_r<void, T, Future&>::value, |
        "The callback must have signature void(Future&)");
1013 | #endif |
1014 | std::unique_lock<std::mutex> lock(mutex_); |
1015 | if (completed()) { |
1016 | lock.unlock(); |
1017 | invokeCallback(std::move(callback)); |
1018 | return; |
1019 | } |
1020 | callbacks_.emplace_back(std::move(callback)); |
1021 | } |
1022 | |
1023 | /** |
1024 | * Add a callback to the future, and return another Future to hold the return |
1025 | * value of the callback. This is necessary when the callback provider needs |
1026 | * to know for sure when the callback has finished. |
1027 | */ |
1028 | template <typename T> |
1029 | c10::intrusive_ptr<Future> then(T callback, TypePtr type) { |
1030 | using IValueWithStorages = std::tuple<IValue, std::vector<WeakStorage>>; |
1031 | #if __cpp_lib_is_invocable >= 201703 |
1032 | static_assert( |
1033 | guts::disjunction< |
1034 | std::is_invocable_r<IValue, T, Future&>, |
1035 | std::is_invocable_r<IValueWithStorages, T, Future&>>::value, |
1036 | "The callback must have signature IValue(Future&) or " |
        "std::tuple<IValue, std::vector<Storage>>(Future&)");
1038 | #endif |
1039 | auto childFut = createInstance(std::move(type)); |
1040 | addCallback([childFut, |
1041 | cb = std::move(callback)](Future& parentFut) mutable { |
1042 | try { |
1043 | guts::if_constexpr<std::is_convertible< |
1044 | typename c10::invoke_result_t<T &&, Future&>, |
1045 | IValueWithStorages>::value>( |
1046 | [&](auto identity) { |
1047 | IValue value; |
1048 | std::vector<WeakStorage> storages; |
1049 | std::tie(value, storages) = identity(cb)(parentFut); |
1050 | childFut->markCompleted(std::move(value), std::move(storages)); |
1051 | }, |
1052 | [&](auto identity) { |
1053 | childFut->markCompleted(identity(cb)(parentFut)); |
1054 | }); |
1055 | } catch (std::exception&) { |
1056 | childFut->setError(std::current_exception()); |
1057 | } |
1058 | }); |
1059 | return childFut; |
1060 | } |
1061 | |
1062 | template <typename T> |
1063 | c10::intrusive_ptr<Future> thenAsync(T callback, TypePtr type) { |
1064 | #if __cpp_lib_is_invocable >= 201703 |
1065 | static_assert( |
1066 | std::is_invocable_r<c10::intrusive_ptr<Future>, T, Future&>::value, |
        "The callback must have signature c10::intrusive_ptr<Future>(Future&)");
1068 | #endif |
1069 | auto childFut = createInstance(std::move(type)); |
1070 | addCallback( |
1071 | [childFut, cb = std::move(callback)](Future& parentFut) mutable { |
1072 | c10::intrusive_ptr<Future> intermediateFut; |
1073 | try { |
1074 | intermediateFut = cb(parentFut); |
1075 | } catch (std::exception&) { |
1076 | childFut->setError(std::current_exception()); |
1077 | return; |
1078 | } |
1079 | intermediateFut->addCallback( |
1080 | [childFut = std::move(childFut)](Future& intermediateFut) { |
1081 | if (intermediateFut.hasError()) { |
1082 | childFut->setError(intermediateFut.exception_ptr()); |
1083 | } else { |
1084 | childFut->markCompleted( |
1085 | intermediateFut.value(), intermediateFut.storages()); |
1086 | } |
1087 | }); |
1088 | }); |
1089 | return childFut; |
1090 | } |
1091 | |
1092 | // Tries to retrieve the error message from std::exception_ptr. |
1093 | std::string tryRetrieveErrorMessage() const { |
    TORCH_CHECK(hasError(), "No error present on the future.");
1095 | std::unique_lock<std::mutex> lock(mutex_); |
1096 | return tryRetrieveErrorMessageInternal(eptr_); |
1097 | } |
1098 | |
1099 | // Check if the current future has completed |
1100 | bool completed() const { |
1101 | return completed_; |
1102 | } |
1103 | |
1104 | bool hasValue() const { |
1105 | std::unique_lock<std::mutex> lock(mutex_); |
1106 | return completed_ && !eptr_; |
1107 | } |
1108 | |
1109 | bool hasError() const { |
1110 | std::unique_lock<std::mutex> lock(mutex_); |
1111 | return eptr_ ? true : false; |
1112 | } |
1113 | |
1114 | std::exception_ptr exception_ptr() const { |
1115 | std::unique_lock<std::mutex> lock(mutex_); |
1116 | return eptr_; |
1117 | } |
1118 | |
1119 | TORCH_API friend std::ostream& operator<<( |
1120 | std::ostream& out, |
1121 | const Future& v); |
1122 | |
1123 | TypePtr elementType() const { |
1124 | return type_; |
1125 | } |
1126 | |
1127 | const std::vector<c10::Device>& devices() const { |
1128 | return devices_; |
1129 | } |
1130 | |
1131 | // This method should be used when one intends to manually create a child |
1132 | // future, for example when implementing a customized version of then(). |
1133 | c10::intrusive_ptr<Future> createInstance(at::TypePtr type) { |
1134 | return c10::make_intrusive<Future>(std::move(type), devices_); |
1135 | } |
1136 | |
1137 | private: |
1138 | |
1139 | // This method should always be used when invoking a callback (regardless of |
1140 | // how/when that happens) as it will ensure that the proper "environment" is |
1141 | // set up before running the callback, as in, it will set up the CUDA streams, |
1142 | // synchronize them with the value, and so on (if needed). |
1143 | template<typename T> |
1144 | void invokeCallback(T callback) { |
1145 | #if __cpp_lib_is_invocable >= 201703 |
1146 | static_assert( |
1147 | std::is_invocable_r<void, T, Future&>::value, |
        "The callback must have signature void(Future&)");
1149 | #endif |
1150 | |
1151 | c10::OptionalDeviceGuard deviceGuard(currentDevice_); |
1152 | |
1153 | std::vector<c10::Stream> streams; |
1154 | streams.reserve(devices_.size()); |
1155 | for (const c10::Device& device : devices_) { |
1156 | streams.push_back(impl_.getStreamFromGlobalPool(device)); |
1157 | } |
1158 | c10::MultiStreamGuard streamGuard(streams); |
1159 | synchronizeWithCurrentStreams(); |
1160 | |
1161 | callback(*this); |
1162 | } |
1163 | |
1164 | // This method should be called before this future's value is used, as it |
1165 | // ensures that the CUDA streams that are "current" at the callsite properly |
1166 | // synchronize with the value. |
1167 | void synchronizeWithCurrentStreams() { |
1168 | for (c10::Event& event : events_) { |
1169 | event.block(impl_.getStream(event.device())); |
1170 | } |
1171 | |
1172 | for (const WeakStorage& weak_storage : storages_) { |
1173 | c10::intrusive_ptr<c10::StorageImpl> storage = weak_storage.lock(); |
1174 | if (!storage) { |
1175 | continue; |
1176 | } |
1177 | if (!storage->device().is_cpu()) { |
1178 | impl_.recordDataPtrOnStream( |
1179 | storage->data_ptr(), impl_.getStream(storage->device())); |
1180 | } |
1181 | } |
1182 | } |
1183 | |
1184 | void setErrorInternal( |
1185 | std::exception_ptr eptr, |
1186 | std::unique_lock<std::mutex>& lock) { |
1187 | TORCH_CHECK( |
1188 | !eptr_, |
        "Error already set on this Future: ",
        tryRetrieveErrorMessageInternal(eptr_),
        ", trying to set error: ",
        tryRetrieveErrorMessageInternal(eptr));
    TORCH_INTERNAL_ASSERT(!completed(), "Future is already marked completed");
1194 | completed_ = true; |
1195 | eptr_ = std::move(eptr); |
1196 | |
1197 | std::vector<std::function<void(Future&)>> cbs; |
1198 | cbs.swap(callbacks_); |
1199 | lock.unlock(); |
1200 | |
1201 | finished_cv_.notify_all(); |
1202 | for (auto& callback : cbs) { |
1203 | invokeCallback(std::move(callback)); |
1204 | } |
1205 | } |
1206 | |
1207 | // Tries to retrieve the error message from std::exception_ptr. |
1208 | std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) const { |
1209 | try { |
1210 | std::rethrow_exception(std::move(eptr)); |
1211 | } catch (const std::exception& e) { |
1212 | return e.what(); |
1213 | } catch (...) { |
      return "Unknown Exception Type";
1215 | } |
1216 | } |
1217 | |
1218 | // Defined in ivalue.cpp. |
  static std::vector<WeakStorage> extractStorages(
      const at::IValue& value);
1221 | |
1222 | static std::vector<c10::Device> getDevicesOfStorages( |
1223 | const c10::impl::VirtualGuardImpl& impl, |
1224 | const std::vector<WeakStorage>& storages) { |
1225 | c10::DeviceIndex deviceCount = impl.deviceCount(); |
1226 | std::vector<bool> isDeviceUsed(deviceCount, false); |
1227 | for (const WeakStorage& weak_storage : storages) { |
1228 | c10::intrusive_ptr<c10::StorageImpl> storage = weak_storage.lock(); |
1229 | if (!storage) { |
1230 | continue; |
1231 | } |
1232 | c10::Device device = storage->device(); |
1233 | if (!device.is_cpu()) { |
1234 | TORCH_CHECK_VALUE( |
1235 | device.type() == impl.type(), |
          "Expected all data ptrs to be on a device of type ",
          impl.type(),
          ", got one on device ",
1239 | device); |
1240 | isDeviceUsed[device.index()] = true; |
1241 | } |
1242 | } |
1243 | std::vector<c10::Device> devices; |
1244 | for (c10::DeviceIndex idx = 0; idx < deviceCount; idx++) { |
1245 | if (isDeviceUsed[idx]) { |
1246 | devices.emplace_back(impl.type(), idx); |
1247 | } |
1248 | } |
1249 | return devices; |
1250 | } |
1251 | |
1252 | static std::string formatSetOfDevices( |
1253 | const std::vector<c10::Device>& devices) { |
1254 | if (devices.empty()) { |
      return "(none)";
1256 | } |
1257 | std::ostringstream oss; |
1258 | oss << devices[0]; |
1259 | for (const auto idx : c10::irange(1, devices.size())) { |
1260 | if (idx == devices.size() - 1) { |
        oss << " and ";
      } else {
        oss << ", ";
1264 | } |
1265 | oss << devices[idx]; |
1266 | } |
1267 | return oss.str(); |
1268 | } |
1269 | |
1270 | static c10::DeviceType getTypeOfDevices( |
1271 | const std::vector<c10::Device>& devices) { |
1272 | if (devices.empty()) { |
1273 | return c10::kCPU; |
1274 | } |
1275 | c10::DeviceType deviceType = devices[0].type(); |
1276 | for (const auto idx : c10::irange(1, devices.size())) { |
1277 | TORCH_CHECK_VALUE( |
1278 | devices[idx].type() == deviceType, |
        "Expected all devices to be of the same type, but got a mismatch between ",
        devices[0],
        " and ",
1282 | devices[idx]); |
1283 | } |
1284 | return deviceType; |
1285 | } |
1286 | |
1287 | // We need devices to be sorted in order to use ensureIsSubsetOfDevices. |
1288 | static std::vector<c10::Device> sortAndDeduplicateDevices( |
1289 | const c10::impl::VirtualGuardImpl& /*impl*/, |
1290 | std::vector<c10::Device> devices) { |
1291 | std::sort( |
1292 | devices.begin(), devices.end(), |
1293 | [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); }); |
1294 | // Deduplicate by compacting. |
1295 | size_t targetIdx = 0; |
1296 | for (const auto sourceIdx : c10::irange(devices.size())) { |
1297 | TORCH_CHECK_VALUE( |
1298 | devices[sourceIdx].has_index(), |
        "Expected devices to have indices, got ", devices[sourceIdx]);
1300 | if (targetIdx > 0 && devices[targetIdx - 1].index() == devices[sourceIdx].index()) { |
1301 | // It's a duplicate, skip it. |
1302 | continue; |
1303 | } |
1304 | if (sourceIdx != targetIdx) { |
1305 | devices[targetIdx] = devices[sourceIdx]; |
1306 | } |
1307 | targetIdx++; |
1308 | } |
1309 | // If there were duplicates there's now a gap at the end: trim it. Resizing |
1310 | // requires the item type to be default-constructible (which c10::Device is |
1311 | // not) because in principle it could be required to create new items. Since |
1312 | // we know we'll shrink the vector, we provide a custom dummy value instead. |
1313 | devices.resize(targetIdx, c10::Device(c10::kCPU)); |
1314 | return devices; |
1315 | } |
1316 | |
1317 | static void ensureIsSubsetOfDevices( |
1318 | const std::vector<c10::Device>& subset, |
1319 | const std::vector<c10::Device>& superset) { |
1320 | // We assume the devices in both vectors have the same consistent type, and |
1321 | // their indices are unique and sorted. |
1322 | std::vector<c10::Device> excessDevices; |
1323 | std::set_difference( |
1324 | subset.begin(), |
1325 | subset.end(), |
1326 | superset.begin(), |
1327 | superset.end(), |
1328 | std::back_inserter(excessDevices), |
1329 | [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); }); |
1330 | TORCH_CHECK_VALUE( |
1331 | excessDevices.empty(), |
      "The result contained tensors residing on device(s) ",
      formatSetOfDevices(excessDevices),
      " which are not among the expected device(s) ",
1335 | formatSetOfDevices(superset)); |
1336 | } |
1337 | |
1338 | mutable std::mutex mutex_; |
1339 | std::atomic_bool completed_ = {false}; // is this future complete |
1340 | std::condition_variable finished_cv_; |
1341 | |
1342 | IValue value_; // when finished the value |
1343 | TypePtr type_; |
1344 | std::vector<std::function<void(Future&)>> callbacks_; |
1345 | std::exception_ptr eptr_; |
1346 | |
1347 | // An upcast pointer to a virtual class which allows us to manipulate events, |
1348 | // streams, ... in a generic way, without an explicit dependency on CUDA. |
1349 | const c10::impl::VirtualGuardImpl impl_; |
1350 | |
1351 | // The device that was current when markCompleted was called, which we'll |
1352 | // restore when invoking callbacks. It's optional because we'll only store it |
1353 | // if the future completes successfully. |
1354 | optional<c10::Device> currentDevice_; |
1355 | |
1356 | // The events that correspond to the completion of the async I/O kernels. They |
1357 | // are recorded on the appropriate streams when the future is marked completed |
1358 | // and can then be queried/waited/blocked on. There is one event for each |
1359 | // distinct device on which the value's tensors reside. |
1360 | std::vector<c10::Event> events_; |
1361 | |
1362 | // A cached version of the storages extracted from the value when the future |
1363 | // is first marked completed. |
1364 | std::vector<WeakStorage> storages_; |
1365 | |
1366 | // The bounding set of devices that this future, and any of its children, is |
1367 | // allowed to use. This is a superset of the set of devices used by the events |
1368 | // above. We need this to know what streams (for which devices) to set as |
1369 | // current when invoking a callback, thus allowing the callback to use devices |
1370 | // that the parent future didn't use. This field is set to the value provided |
1371 | // in the constructor and will be "inherited" by all child futures. |
1372 | const std::vector<c10::Device> devices_; |
1373 | }; |
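
// A sketch of typical Future usage. The future here is CPU-only (no device
// list is passed), so no stream synchronization is involved:
// ```
// auto fut = c10::make_intrusive<ivalue::Future>(IntType::get());
// fut->addCallback([](ivalue::Future& f) {
//   // Runs with the completed future; check hasError() before value().
//   if (!f.hasError()) {
//     int64_t v = f.value().toInt();
//     (void)v;
//   }
// });
// fut->markCompleted(IValue(static_cast<int64_t>(7))); // invokes the callback
// ```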
1374 | |
1375 | struct C10_EXPORT ivalue::Await final : c10::intrusive_ptr_target { |
1376 | private: |
1377 | explicit Await(TypePtr elType, std::function<IValue()> fn) |
1378 | : elType_(std::move(elType)), type_(AwaitType::create(elType_)), fn_(std::move(fn)) {} |
1379 | |
1380 | explicit Await(TypePtr elType) : elType_(std::move(elType)), type_(AwaitType::create(elType_)) { } |
1381 | |
1382 | friend c10::intrusive_ptr<Await>; |
1383 | |
1384 | public: |
1385 | Await(const Await&) = delete; |
1386 | Await(Await&&) = delete; |
1387 | Await& operator=(const Await&) = delete; |
1388 | Await& operator=(Await&&) = delete; |
1389 | |
1390 | IValue wait() { |
1391 | if (!completed_) { |
      TORCH_CHECK(fn_, "Incomplete Await: fn can't be None");
1393 | value_ = fn_(); |
1394 | completed_ = true; |
1395 | args_ = {}; |
1396 | } |
1397 | return value_; |
1398 | } |
1399 | |
1400 | IValue value() { |
    TORCH_CHECK(completed_, "Await must be completed");
1402 | return value_; |
1403 | } |
1404 | |
1405 | void setFn(std::function<IValue()> fn) { |
1406 | fn_ = std::move(fn); |
1407 | } |
1408 | |
1409 | bool completed() { |
1410 | return completed_; |
1411 | } |
1412 | |
1413 | void markCompleted(IValue value) { |
1414 | value_ = std::move(value); |
1415 | completed_ = true; |
1416 | } |
1417 | |
1418 | TORCH_API friend std::ostream& operator<<( |
1419 | std::ostream& out, |
1420 | const Await& v); |
1421 | |
1422 | TypePtr elementType() const { |
1423 | return elType_; |
1424 | } |
1425 | |
1426 | TypePtr type() const { |
1427 | return type_; |
1428 | } |
1429 | |
1430 | void setArgs(std::vector<IValue> args) { |
1431 | args_ = std::move(args); |
1432 | } |
1433 | |
1434 | std::vector<IValue>& args() { |
1435 | return args_; |
1436 | } |
1437 | |
1438 | private: |
1439 | TypePtr elType_; |
1440 | TypePtr type_; |
1441 | std::vector<IValue> args_; |
1442 | std::function<IValue()> fn_; |
1443 | IValue value_; |
1444 | bool completed_{}; |
1445 | }; |
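
// A sketch of Await usage: the stored function runs lazily on the first
// wait() call and the result is cached for later value() accesses:
// ```
// auto aw = c10::make_intrusive<ivalue::Await>(
//     IntType::get(), []() { return IValue(static_cast<int64_t>(2)); });
// IValue v = aw->wait();   // runs fn_, marks the Await completed
// IValue v2 = aw->value(); // returns the cached result; fn_ is not re-run
// ```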
1446 | |
1447 | // Input is a list of Futures with the same target type. |
1448 | // Output is a Future to the List of completed Futures. |
1449 | TORCH_API intrusive_ptr<ivalue::Future> collectAll( |
1450 | c10::List<c10::intrusive_ptr<ivalue::Future>> srcs); |
1451 | // Input is a List of Futures with the same target type. |
1452 | // Output is a Future that will be updated with a seen value. |
1453 | TORCH_API intrusive_ptr<ivalue::Future> collectAny( |
1454 | c10::List<c10::intrusive_ptr<ivalue::Future>> srcs); |
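
// For example (a sketch; fut1 and fut2 are assumed to be existing
// intrusive_ptr<ivalue::Future> values with Int element type):
// ```
// c10::List<c10::intrusive_ptr<ivalue::Future>> futs(
//     FutureType::create(IntType::get()));
// futs.push_back(fut1);
// futs.push_back(fut2);
// auto all = collectAll(futs); // completes once fut1 and fut2 both complete
// ```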
1455 | |
1456 | // User-defined object. |
1457 | struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target { |
1458 | public: |
  // In general, class types hold a shared_ptr to their owning CompilationUnit,
  // so that their types and methods do not get deallocated while the class
  // exists. However, the CompilationUnit holds ownership of the type's graphs,
  // so inserting a constant object into a Graph would create a reference cycle
  // if that constant object held a shared_ptr to its CU. For these objects we
  // instantiate them with non-owning references to their CU.
1465 | Object(WeakOrStrongTypePtr type, size_t numSlots) : type_(std::move(type)) { |
1466 | slots_.resize(numSlots); |
1467 | } |
1468 | |
1469 | Object(StrongTypePtr type, size_t numSlots) |
1470 | : type_(WeakOrStrongTypePtr(std::move(type))) { |
1471 | slots_.resize(numSlots); |
1472 | } |
1473 | |
1474 | static c10::intrusive_ptr<Object> create( |
1475 | WeakOrStrongTypePtr type, |
1476 | size_t numSlots) { |
1477 | return c10::make_intrusive<Object>(std::move(type), numSlots); |
1478 | } |
1479 | |
1480 | static c10::intrusive_ptr<Object> create( |
1481 | StrongTypePtr type, |
1482 | size_t numSlots) { |
1483 | return c10::make_intrusive<Object>(std::move(type), numSlots); |
1484 | } |
1485 | |
1486 | static c10::intrusive_ptr<Object> create(ClassTypePtr classType, size_t numSlots); |
1487 | |
1488 | /** |
1489 | * Slot API. |
1490 | * |
1491 | * Attributes are stored as a simple vector so that lookups are fast at |
1492 | * runtime. A "slot" is just an index into that vector, which can be computed |
1493 | * statically if you have access to the class type. Use this API if you are |
1494 | * writing compiler stuff. |
1495 | */ |
1496 | void setSlot(size_t slot, IValue v) { |
1497 | if (slot >= slots_.size()) { |
1498 | // for module types, it is possible that the members of the class have |
1499 | // expanded after the object was created. In this case, we expand |
1500 | // the slots to the right size |
1501 | resizeObject(slot); |
1502 | } |
1503 | slots_[slot] = std::move(v); |
1504 | } |
1505 | |
1506 | const IValue& getSlot(size_t slot) const { |
1507 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(slot < slots_.size()); |
1508 | // NOTE: This lookup is fairly hot, so we use unchecked access to the |
1509 | // vector. Errors should still be detectable with ASan. |
1510 | return slots_[slot]; |
1511 | } |
1512 | |
1513 | void unsafeRemoveSlot(size_t slot) { |
1514 | TORCH_CHECK(slot < slots_.size()); |
1515 | slots_.erase(slots_.begin() + slot); |
1516 | } |
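
  // A minimal sketch of the slot API (hypothetical setup: `cu` is a
  // CompilationUnit and `cls` a ClassType with two attributes registered at
  // slots 0 and 1):
  //
  //   auto obj = ivalue::Object::create(
  //       c10::StrongTypePtr(cu, cls), /*numSlots=*/2);
  //   obj->setSlot(0, IValue(1.0));
  //   obj->setSlot(1, IValue(std::string("hello")));
  //   double x = obj->getSlot(0).toDouble();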
1517 | |
1518 | /** |
1519 | * Attribute API. |
1520 | * |
1521 | * Wrappers around the slot stuff so that users can access attributes |
1522 | * directly. Use this API if you are a user. |
1523 | * |
1524 | * Note: Unlike in Python, TorchScript must make a distinction between |
1525 | * attributes (which are IValues) and methods (which are Methods). If you |
1526 | * want a method, use `obj.type()->getMethod()` |
1527 | */ |
1528 | IValue getAttr(const std::string& name) const; |
1529 | void setAttr(const std::string& name, IValue v); |
  // Removes an attribute by name; the caller is responsible for the safety
  // of this operation. We do not remove the attribute from the type, because
  // the type might be shared by multiple objects. After removal the object
  // is therefore in an inconsistent state, with more attributes in its Type
  // than it has attribute slots; the caller must restore consistency by
  // removing the attribute from the type as well.
1538 | void unsafeRemoveAttr(const std::string& name); |
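
  // A minimal sketch of the attribute API, assuming `obj` was created from a
  // class type that declares an attribute named "weight" (a hypothetical
  // name):
  //
  //   obj->setAttr("weight", IValue(at::ones({2, 2})));
  //   at::Tensor w = obj->getAttr("weight").toTensor();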
1539 | |
1540 | std::string name() const; |
1541 | |
1542 | const std::vector<IValue>& slots() const { |
1543 | return slots_; |
1544 | } |
1545 | std::shared_ptr<ClassType> type() const; |
1546 | |
1547 | std::shared_ptr<torch::jit::CompilationUnit> compilation_unit() { |
1548 | if (type_.holds_strong_ref()) { |
1549 | return type_.cu_.getStrongRefOrThrow(); |
1550 | } else { |
1551 | auto weak_ptr = type_.cu_.getWeakRefOrThrow(); |
1552 | return std::shared_ptr<torch::jit::CompilationUnit>(weak_ptr); |
1553 | } |
1554 | } |
1555 | |
1556 | c10::intrusive_ptr<Object> copy_to_weak_compilation_ref() const; |
1557 | |
1558 | void unsafe_make_weak_compilation_ref() { |
1559 | type_ = WeakOrStrongTypePtr(type_.asWeakTypePtr()); |
1560 | } |
1561 | |
1562 | c10::intrusive_ptr<Object> copy() const; |
1563 | |
1564 | c10::intrusive_ptr<Object> deepcopy() const; |
1565 | |
1566 | c10::intrusive_ptr<Object> deepcopy(IValue::HashAliasedIValueMap& memo) const; |
1567 | |
1568 | bool is_weak_compilation_ref() const { |
1569 | return !type_.holds_strong_ref(); |
1570 | } |
1571 | |
1572 | bool is_empty_strong_compilation_ref() const { |
1573 | return type_.holds_empty_strong_ref(); |
1574 | } |
1575 | |
1576 | private: |
1577 | void resizeObject(size_t slot); |
1578 | WeakOrStrongTypePtr type_; |
1579 | std::vector<IValue> slots_; |
1580 | }; |
1581 | |
// A virtual ivalue PyObjectHolder that holds a py::object. We make this
// virtual because the py::object and its refcounting logic should live in
// libtorch_python; see the concrete implementation in python_ivalue.h.
1585 | struct ivalue::PyObjectHolder : c10::intrusive_ptr_target { |
1586 | public: |
1587 | virtual PyObject* getPyObject() = 0; |
1588 | virtual c10::InferredType tryToInferType() = 0; |
1589 | virtual IValue toIValue(const TypePtr& type, c10::optional<int32_t> N = c10::nullopt) = 0; |
1590 | virtual std::string toStr() = 0; |
  virtual std::vector<at::Tensor> extractTensors() = 0;
1592 | |
1593 | ~PyObjectHolder() override = default; |
1594 | }; |
1595 | |
1596 | struct ivalue::EnumHolder : c10::intrusive_ptr_target { |
1597 | public: |
1598 | EnumHolder(std::shared_ptr<EnumType> type, std::string name, IValue value) |
1599 | : type_(std::move(type)), |
1600 | name_(std::move(name)), |
1601 | value_(std::move(value)) {} |
1602 | |
  bool is(const ivalue::EnumHolder& rhs) const {
1604 | return *this == rhs; |
1605 | } |
1606 | |
1607 | friend bool operator==( |
1608 | const ivalue::EnumHolder& lhs, |
1609 | const ivalue::EnumHolder& rhs); |
1610 | |
1611 | TORCH_API friend std::ostream& operator<<( |
1612 | std::ostream& out, |
1613 | const EnumHolder& v); |
1614 | |
1615 | TORCH_API const std::string qualifiedClassName() const; |
1616 | |
1617 | const std::string unqualifiedClassName() const; |
1618 | |
1619 | const std::string& name() const { |
1620 | return name_; |
1621 | } |
1622 | |
1623 | const IValue& value() const { |
1624 | return value_; |
1625 | } |
1626 | |
1627 | std::shared_ptr<EnumType> type() const { |
1628 | return type_; |
1629 | } |
1630 | |
1631 | private: |
1632 | std::shared_ptr<EnumType> type_; |
1633 | std::string name_; |
1634 | IValue value_; |
1635 | }; |
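
// A minimal sketch of constructing and querying an EnumHolder (`color_type`
// is a hypothetical std::shared_ptr<EnumType> describing `Color.RED = 0`):
//
//   auto red = c10::make_intrusive<ivalue::EnumHolder>(
//       color_type, "RED", IValue(static_cast<int64_t>(0)));
//   TORCH_INTERNAL_ASSERT(red->name() == "RED");
//   TORCH_INTERNAL_ASSERT(red->value().toInt() == 0);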
1636 | |
1637 | #undef TORCH_FORALL_TAGS |
1638 | |
1639 | namespace detail { |
1640 | |
1641 | struct _guarded_unsigned_long_unique_dummy final { |
  _guarded_unsigned_long_unique_dummy(int64_t) {}
1643 | }; |
1644 | using _guarded_unsigned_long = std::conditional_t< |
1645 | std::is_same<unsigned long, uint32_t>::value || |
1646 | std::is_same<unsigned long, uint64_t>::value, |
1647 | _guarded_unsigned_long_unique_dummy, |
1648 | unsigned long>; |
1649 | |
1650 | } // namespace detail |
1651 | |
1652 | inline ivalue::Object& IValue::toObjectRef() const { |
1653 | AT_ASSERT(isObject(), "Expected Object but got " , tagKind()); |
1654 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "Attempted to create null reference" ); |
1655 | return *static_cast<c10::ivalue::Object*>(payload.u.as_intrusive_ptr); |
1656 | } |
1657 | |
1658 | // note: when adding a DEFINE_TO case here you should also add a |
1659 | // toX method to IValue. These named methods are much more discoverable |
1660 | // than the to templated function. |
1661 | |
1662 | #define DEFINE_TO(T, method_name) \ |
1663 | template <> \ |
1664 | inline T IValue::to<T>()&& { \ |
1665 | return static_cast<T>(std::move(*this).method_name()); \ |
1666 | } \ |
1667 | template <> \ |
1668 | inline c10::detail::ivalue_to_const_ref_overload_return<T>::type IValue::to<T>() const& { \ |
1669 | typedef c10::detail::ivalue_to_const_ref_overload_return<T>::type return_type; \ |
1670 | return static_cast<return_type>(this->method_name()); \ |
1671 | } |
1672 | |
1673 | DEFINE_TO(at::Tensor, toTensor) |
1674 | DEFINE_TO(at::Storage, toStorage) |
1675 | DEFINE_TO(c10::Stream, toStream) |
1676 | DEFINE_TO(float, toDouble) |
1677 | DEFINE_TO(double, toDouble) |
1678 | DEFINE_TO(c10::complex<double>, toComplexDouble) |
1679 | DEFINE_TO(unsigned char, toInt) |
1680 | DEFINE_TO(signed char, toInt) |
1681 | DEFINE_TO(unsigned short, toInt) |
1682 | DEFINE_TO(short, toInt) |
1683 | DEFINE_TO(int, toInt) |
1684 | DEFINE_TO(uint32_t, toInt) |
1685 | DEFINE_TO(uint64_t, toInt) |
1686 | DEFINE_TO(detail::_guarded_unsigned_long, toInt) |
1687 | DEFINE_TO(int64_t, toInt) |
1688 | DEFINE_TO(bool, toBool) |
DEFINE_TO(c10::intrusive_ptr<caffe2::Blob>, toBlob)
1690 | DEFINE_TO(c10::intrusive_ptr<ivalue::ConstantString>, toString) |
1691 | DEFINE_TO(c10::intrusive_ptr<ivalue::Object>, toObject) |
1692 | DEFINE_TO(at::Scalar, toScalar) |
1693 | DEFINE_TO(c10::List<int64_t>, toIntList) |
1694 | DEFINE_TO(c10::List<double>, toDoubleList) |
1695 | DEFINE_TO(c10::List<c10::complex<double>>, toComplexDoubleList) |
1696 | DEFINE_TO(c10::List<bool>, toBoolList) |
1697 | DEFINE_TO(c10::List<at::Tensor>, toTensorList) |
1698 | DEFINE_TO(c10::impl::GenericList, toList) |
1699 | DEFINE_TO(c10::impl::GenericDict, toGenericDict) |
1700 | DEFINE_TO(c10::intrusive_ptr<ivalue::Tuple>, toTuple) |
1701 | DEFINE_TO(std::string, toStringRef) |
1702 | DEFINE_TO(c10::string_view, toStringView) |
1703 | DEFINE_TO(c10::intrusive_ptr<ivalue::Future>, toFuture) |
1704 | DEFINE_TO(c10::intrusive_ptr<ivalue::Await>, toAwait) |
1705 | DEFINE_TO(c10::intrusive_ptr<c10::RRefInterface>, toRRef) |
1706 | DEFINE_TO(c10::intrusive_ptr<at::Quantizer>, toQuantizer) |
1707 | DEFINE_TO(IValue, toIValue) |
1708 | DEFINE_TO(c10::Device, toDevice) |
1709 | DEFINE_TO(at::ScalarType, toScalarType) |
1710 | DEFINE_TO(at::Layout, toLayout) |
1711 | DEFINE_TO(at::MemoryFormat, toMemoryFormat) |
1712 | DEFINE_TO(at::QScheme, toQScheme) |
1713 | DEFINE_TO(at::Dimname, toDimname) |
1714 | DEFINE_TO(at::Generator, toGenerator) |
1715 | DEFINE_TO(c10::SymInt, toSymInt) |
1716 | DEFINE_TO(c10::SymFloat, toSymFloat) |
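
// The specializations above are what make the generic accessor work; a
// minimal sketch of the two overloads each DEFINE_TO expands to:
//
//   IValue iv(static_cast<int64_t>(42));
//   int64_t a = iv.to<int64_t>();             // const& overload; iv survives
//   int64_t b = std::move(iv).to<int64_t>();  // && overload; iv is consumed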
1717 | |
1718 | template <class T> |
1719 | struct _fake_type {}; |
1720 | |
1721 | // generic_to<T> converts an IValue from a generic list or generic dict |
// to a concrete list/dict type like List<T>, Dict<...> or optional<T>.
1723 | // Note that in the case of lists, this only works for IValue-based lists, |
1724 | // i.e. not for int64_t, double, ... |
1725 | // generic_to<T> is an implementation detail of IValue::to<T> and not |
1726 | // supposed to be called directly. |
1727 | // The _fake_type<T> parameter allows us to overload |
1728 | // based on the return type. |
1729 | template <class Elem> |
1730 | // TODO this is deprecated but we don't throw a warning because a lot of ops in |
1731 | // native_functions.yaml still return std::vector. |
1732 | // C10_DEPRECATED_MESSAGE("IValues based on std::vector<T> are potentially slow |
1733 | // and deprecated. Please use torch::List<T> instead.") |
1734 | std::vector<Elem> generic_to(IValue ivalue, _fake_type<std::vector<Elem>>) { |
1735 | // We need to do a deep copy of the vector because there might be other |
1736 | // references to this same IValue that also use the list. We can't just |
1737 | // move the elements out. |
1738 | auto list = std::move(ivalue).to<List<Elem>>(); |
1739 | std::vector<Elem> result; |
1740 | result.reserve(list.size()); |
1741 | for (Elem v : list) { |
1742 | result.push_back(std::move(v)); |
1743 | } |
1744 | return result; |
1745 | } |
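
// For example, because of the deep copy above, mutating the returned vector
// leaves the original list untouched (a sketch, not code from this header):
//
//   IValue iv(c10::List<int64_t>({1, 2, 3}));
//   std::vector<int64_t> v = iv.to<std::vector<int64_t>>();  // copies
//   v[0] = 7;  // the list still held by iv is unchanged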
1746 | |
1747 | template <typename T> |
1748 | c10::intrusive_ptr<T> IValue::toCustomClass() && { |
1749 | static_assert( |
1750 | std::is_base_of<torch::CustomClassHolder, T>::value == true, |
1751 | "toCustomClass requires that template parameter T must inherit " |
1752 | "from torch::CustomClassHolder" ); |
1753 | auto obj = toObject(); |
1754 | TORCH_CHECK( |
1755 | obj->slots().size() == 1, |
1756 | "Tried to cast IValue to custom class but it did " |
1757 | "not contain a custom class!" ); |
1758 | const auto* expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>().get(); |
1759 | ivalue::checkCustomClassType(expected_type, type().get()); |
1760 | auto userObj = |
1761 | c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule()); |
1762 | return userObj; |
1763 | } |
1764 | |
1765 | template <typename T> |
1766 | c10::intrusive_ptr<T> IValue::toCustomClass() const& { |
1767 | static_assert( |
1768 | std::is_base_of<torch::CustomClassHolder, T>::value == true, |
1769 | "toCustomClass requires that template parameter T must inherit " |
1770 | "from torch::CustomClassHolder" ); |
1771 | auto obj = toObject(); |
1772 | TORCH_CHECK( |
1773 | obj->slots().size() == 1, |
1774 | "Tried to cast IValue to custom class but it did " |
1775 | "not contain a custom class!" ); |
1776 | const auto* expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>().get(); |
1777 | ivalue::checkCustomClassType(expected_type, type().get()); |
1778 | auto userObj = |
1779 | c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule()); |
1780 | return userObj; |
1781 | } |
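
// A minimal sketch of the custom-class round trip (`MyStackClass` is a
// hypothetical type deriving from torch::CustomClassHolder that has been
// registered as a custom class):
//
//   auto stack = c10::make_intrusive<MyStackClass>();
//   IValue iv(stack);                              // boxes the object
//   auto back = iv.toCustomClass<MyStackClass>();  // unboxes it again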
1782 | |
1783 | template <typename T> |
1784 | T generic_to(IValue ivalue, _fake_type<T>) { |
1785 | using ElemType = typename std::remove_pointer<T>::type::element_type; |
1786 | return std::move(ivalue).toCustomClass<ElemType>(); |
1787 | } |
1788 | |
1789 | template <typename T> |
1790 | tagged_capsule<T> generic_to(IValue ivalue, _fake_type<tagged_capsule<T>>) { |
1791 | return tagged_capsule<T>{std::move(ivalue)}; |
1792 | } |
1793 | |
1794 | template <typename Elem> |
1795 | c10::List<Elem> generic_to(IValue ivalue, _fake_type<c10::List<Elem>>) { |
1796 | return impl::toTypedList<Elem>(std::move(ivalue).toList()); |
1797 | } |
1798 | |
1799 | template <typename T> |
1800 | static T createVectorLikeFromList(const c10::detail::ListImpl* impl) { |
1801 | T result; |
1802 | result.reserve(impl->list.size()); |
1803 | for (const auto & i : impl->list) { |
1804 | result.push_back(i.to<typename T::value_type>()); |
1805 | } |
1806 | return result; |
1807 | } |
1808 | |
1809 | template <typename T> |
1810 | static std::vector<T> createVectorFromList(const c10::detail::ListImpl* impl) { |
1811 | return createVectorLikeFromList<std::vector<T>>(impl); |
1812 | } |
1813 | |
1814 | template <typename T> |
1815 | std::vector<T> createVectorFromList(const c10::List<T>& impl) { |
1816 | std::vector<T> result; |
1817 | result.reserve(impl.size()); |
1818 | for (size_t i = 0, N = impl.size(); i < N; ++i) { |
1819 | result.push_back(impl[i]); |
1820 | } |
1821 | return result; |
1822 | } |
1823 | |
1824 | template <typename T> |
1825 | OptionalArray<T> generic_to(IValue ivalue, _fake_type<OptionalArray<T>>) { |
1826 | if (ivalue.isNone()) { |
1827 | return {}; |
1828 | } |
1829 | return createVectorFromList<T>( |
1830 | std::move(ivalue).to<c10::List<T>>() |
1831 | ); |
1832 | } |
1833 | |
1834 | namespace detail { |
1835 | template <typename Elem, size_t... I> |
1836 | std::array<Elem, sizeof...(I)> generic_to_array( |
1837 | IValue ivalue, |
1838 | _fake_type<std::array<Elem, sizeof...(I)>>, |
1839 | std::index_sequence<I...>) { |
1840 | // We need to do a deep copy of the array because there might be other |
1841 | // references to this same IValue that also use the list. We can't just |
1842 | // move the elements out. |
1843 | auto list = std::move(ivalue).to<List<Elem>>(); |
1844 | TORCH_CHECK( |
1845 | list.size() == sizeof...(I), |
1846 | "Tried to convert a List with " , |
1847 | list.size(), |
1848 | " elements to a fixed-size array of size " , |
1849 | sizeof...(I)); |
1850 | return {list[I]...}; |
1851 | } |
1852 | } // namespace detail |
1853 | |
1854 | template <typename Elem, size_t N> |
1855 | std::array<Elem, N> generic_to( |
1856 | IValue ivalue, |
1857 | _fake_type<std::array<Elem, N>> ft) { |
1858 | return detail::generic_to_array(ivalue, ft, std::make_index_sequence<N>()); |
1859 | } |
1860 | |
1861 | template <typename Key, typename Value> |
1862 | c10::Dict<Key, Value> generic_to( |
1863 | IValue ivalue, |
1864 | _fake_type<c10::Dict<Key, Value>>) { |
1865 | return impl::toTypedDict<Key, Value>(std::move(ivalue).toGenericDict()); |
1866 | } |
1867 | |
1868 | template <typename K, typename V> |
1869 | C10_DEPRECATED_MESSAGE( |
1870 | "IValues based on std::unordered_map are slow and deprecated. Please use c10::Dict<K, V> instead." ) |
1871 | std::unordered_map<K, V> generic_to( |
1872 | IValue ivalue, |
1873 | _fake_type<std::unordered_map<K, V>>) { |
1874 | std::unordered_map<K, V> specialized_dict; |
1875 | |
1876 | for (const auto& item : std::move(ivalue).toGenericDict()) { |
1877 | specialized_dict[item.key().template to<K>()] = item.value().template to<V>(); |
1878 | } |
1879 | |
1880 | return specialized_dict; |
1881 | } |
1882 | |
1883 | template <typename T> |
1884 | c10::optional<T> generic_to(IValue ivalue, _fake_type<c10::optional<T>>) { |
1885 | if (ivalue.isNone()) { |
1886 | return c10::nullopt; |
1887 | } |
1888 | return std::move(ivalue).to<T>(); |
1889 | } |
1890 | |
1891 | namespace detail { |
1892 | template <typename Tuple, std::size_t... INDEX> |
1893 | Tuple generic_to_tuple_impl( |
1894 | const ivalue::TupleElements& t, |
1895 | std::index_sequence<INDEX...>) { |
1896 | return std::make_tuple( |
1897 | t[INDEX].to<typename std::tuple_element<INDEX, Tuple>::type>()...); |
1898 | } |
1899 | } // namespace detail |
1900 | |
1901 | template < |
1902 | typename... Args, |
1903 | typename Indices = std::make_index_sequence<sizeof...(Args)>, |
1904 | std::enable_if_t< |
1905 | !guts::disjunction< |
1906 | std::is_lvalue_reference<Args>..., |
1907 | guts::negation<std::is_constructible<IValue, Args>>...>::value, |
1908 | std::nullptr_t> = nullptr> |
1909 | std::tuple<Args...> generic_to(IValue ivalue, _fake_type<std::tuple<Args...>>) { |
1910 | const auto& vals = ivalue.toTupleRef().elements(); |
1911 | TORCH_CHECK(vals.size() == sizeof...(Args)); |
1912 | return detail::generic_to_tuple_impl<std::tuple<Args...>>(vals, Indices{}); |
1913 | } |
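
// A minimal sketch of the tuple conversion above:
//
//   IValue iv(std::make_tuple(static_cast<int64_t>(1), 2.0));
//   auto t = iv.to<std::tuple<int64_t, double>>();
//   TORCH_INTERNAL_ASSERT(std::get<0>(t) == 1);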
1914 | |
1915 | template <typename T> |
1916 | inline T IValue::to() && { |
1917 | return generic_to(std::move(*this), _fake_type<T>{}); |
1918 | } |
1919 | |
1920 | template <> |
1921 | inline c10::optional<c10::string_view> IValue::to() && { |
1922 | // In the default implementation, the IValue is destroyed with std::move. |
1923 | // But if the unboxed type is optional<string_view> we cannot destroy |
1924 | // the IValue. |
1925 | return generic_to(*this, _fake_type<c10::optional<c10::string_view>>{}); |
1926 | } |
1927 | |
1928 | template <typename T> |
1929 | inline typename c10::detail::ivalue_to_const_ref_overload_return<T>::type IValue::to() const& { |
1930 | return generic_to(*this, _fake_type<T>{}); |
1931 | } |
1932 | |
1933 | inline c10::List<int64_t> IValue::toIntList() && { |
1934 | AT_ASSERT(isIntList(), "Expected IntList but got " , tagKind()); |
1935 | return c10::List<int64_t>(moveToIntrusivePtr<c10::detail::ListImpl>()); |
1936 | } |
1937 | inline c10::List<int64_t> IValue::toIntList() const& { |
1938 | AT_ASSERT(isIntList(), "Expected IntList but got " , tagKind()); |
1939 | return c10::List<int64_t>(toIntrusivePtr<c10::detail::ListImpl>()); |
1940 | } |
1941 | inline std::vector<int64_t> IValue::toIntVector() const { |
1942 | AT_ASSERT(isIntList(), "Expected IntList but got " , tagKind()); |
1943 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
1944 | payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), |
1945 | "called toIntVector on null intrusive_ptr IValue" ); |
1946 | return createVectorFromList<int64_t>( |
1947 | static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr)); |
1948 | } |
1949 | inline at::DimVector IValue::toDimVector() const { |
1950 | AT_ASSERT(isIntList(), "Expected IntList but got " , tagKind()); |
1951 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
1952 | payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), |
1953 | "called toDimVector on null intrusive_ptr IValue" ); |
1954 | return createVectorLikeFromList<at::DimVector>( |
1955 | static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr)); |
1956 | } |
1957 | inline c10::List<double> IValue::toDoubleList() && { |
1958 | AT_ASSERT(isDoubleList(), "Expected DoubleList but got " , tagKind()); |
1959 | return c10::List<double>(moveToIntrusivePtr<c10::detail::ListImpl>()); |
1960 | } |
1961 | inline c10::List<double> IValue::toDoubleList() const& { |
1962 | AT_ASSERT(isDoubleList(), "Expected DoubleList but got " , tagKind()); |
1963 | return c10::List<double>(toIntrusivePtr<c10::detail::ListImpl>()); |
1964 | } |
1965 | inline std::vector<double> IValue::toDoubleVector() const { |
1966 | AT_ASSERT(isDoubleList(), "Expected DoubleList but got " , tagKind()); |
1967 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
1968 | payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), |
1969 | "called toDoubleVector on null intrusive_ptr IValue" ); |
1970 | return createVectorFromList<double>( |
1971 | static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr)); |
1972 | } |
1973 | inline c10::List<c10::complex<double>> IValue::toComplexDoubleList() && { |
1974 | AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got " , tagKind()); |
1975 | return c10::List<c10::complex<double>>(moveToIntrusivePtr<c10::detail::ListImpl>()); |
1976 | } |
1977 | inline c10::List<c10::complex<double>> IValue::toComplexDoubleList() const& { |
1978 | AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got " , tagKind()); |
1979 | return c10::List<c10::complex<double>>(toIntrusivePtr<c10::detail::ListImpl>()); |
1980 | } |
1981 | inline std::vector<c10::complex<double>> IValue::toComplexDoubleVector() const { |
1982 | AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got " , tagKind()); |
1983 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
1984 | payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), |
1985 | "called toComplexDoubleVector on null intrusive_ptr IValue" ); |
1986 | return createVectorFromList<c10::complex<double>>( |
1987 | static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr)); |
1988 | } |
1989 | inline c10::List<bool> IValue::toBoolList() && { |
1990 | AT_ASSERT(isBoolList(), "Expected BoolList but got " , tagKind()); |
1991 | return c10::List<bool>(moveToIntrusivePtr<c10::detail::ListImpl>()); |
1992 | } |
1993 | inline c10::List<bool> IValue::toBoolList() const& { |
1994 | AT_ASSERT(isBoolList(), "Expected BoolList but got " , tagKind()); |
1995 | return c10::List<bool>(toIntrusivePtr<c10::detail::ListImpl>()); |
1996 | } |
1997 | inline c10::List<at::Tensor> IValue::toTensorList() && { |
1998 | AT_ASSERT(isTensorList(), "Expected TensorList but got " , tagKind()); |
1999 | return c10::List<at::Tensor>(moveToIntrusivePtr<c10::detail::ListImpl>()); |
2000 | } |
2001 | inline c10::List<at::Tensor> IValue::toTensorList() const& { |
2002 | AT_ASSERT(isTensorList(), "Expected TensorList but got " , tagKind()); |
2003 | return c10::List<at::Tensor>(toIntrusivePtr<c10::detail::ListImpl>()); |
2004 | } |
2005 | inline std::vector<at::Tensor> IValue::toTensorVector() const { |
2006 | AT_ASSERT(isTensorList(), "Expected TensorList but got " , tagKind()); |
2007 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
2008 | payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), |
2009 | "called toTensorVector on null intrusive_ptr IValue" ); |
2010 | return createVectorFromList<at::Tensor>( |
2011 | static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr)); |
2012 | } |
2013 | inline c10::List<c10::optional<at::Tensor>> IValue::toOptionalTensorList() && { |
2014 | AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got " , tagKind()); |
2015 | return c10::List<c10::optional<at::Tensor>>(moveToIntrusivePtr<c10::detail::ListImpl>()); |
2016 | } |
2017 | inline c10::List<c10::optional<at::Tensor>> IValue::toOptionalTensorList() const& { |
2018 | AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got " , tagKind()); |
2019 | return c10::List<c10::optional<at::Tensor>>(toIntrusivePtr<c10::detail::ListImpl>()); |
2020 | } |
2021 | inline std::vector<c10::optional<at::Tensor>> IValue::toOptionalTensorVector() const { |
2022 | AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got " , tagKind()); |
2023 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
2024 | payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), |
2025 | "called toOptionalTensorVector on null intrusive_ptr IValue" ); |
2026 | return createVectorFromList<c10::optional<at::Tensor>>( |
2027 | static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr)); |
2028 | } |
2029 | inline c10::List<IValue> IValue::toList() && { |
2030 | AT_ASSERT(isList(), "Expected GenericList but got " , tagKind()); |
2031 | return c10::List<IValue>(moveToIntrusivePtr<c10::detail::ListImpl>()); |
2032 | } |
2033 | inline c10::List<IValue> IValue::toList() const& { |
2034 | AT_ASSERT(isList(), "Expected GenericList but got " , tagKind()); |
2035 | return c10::List<IValue>(toIntrusivePtr<c10::detail::ListImpl>()); |
2036 | } |
2037 | inline c10::ArrayRef<IValue> IValue::toListRef() const { |
2038 | AT_ASSERT(isList(), "Expected GenericList but got " , tagKind()); |
2039 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
2040 | payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), |
2041 | "called toListRef on null intrusive_ptr IValue" ); |
2042 | return static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr) |
2043 | ->list; |
2044 | } |
2045 | inline c10::Dict<IValue, IValue> IValue::toGenericDict() && { |
2046 | AT_ASSERT(isGenericDict(), "Expected GenericDict but got " , tagKind()); |
2047 | return c10::Dict<IValue, IValue>(moveToIntrusivePtr<c10::detail::DictImpl>()); |
2048 | } |
2049 | inline c10::Dict<IValue, IValue> IValue::toGenericDict() const& { |
2050 | AT_ASSERT(isGenericDict(), "Expected GenericDict but got " , tagKind()); |
2051 | return c10::Dict<IValue, IValue>(toIntrusivePtr<c10::detail::DictImpl>()); |
2052 | } |
2053 | inline c10::intrusive_ptr<ivalue::Tuple> IValue::toTuple() && { |
2054 | AT_ASSERT(isTuple(), "Expected Tuple but got " , tagKind()); |
2055 | return moveToIntrusivePtr<ivalue::Tuple>(); |
2056 | } |
2057 | inline c10::intrusive_ptr<ivalue::Tuple> IValue::toTuple() const& { |
2058 | AT_ASSERT(isTuple(), "Expected Tuple but got " , tagKind()); |
2059 | return toIntrusivePtr<ivalue::Tuple>(); |
2060 | } |
2061 | inline ivalue::Tuple& IValue::toTupleRef() const { |
2062 | AT_ASSERT(isTuple(), "Expected Tuple but got " , tagKind()); |
2063 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
2064 | payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), |
2065 | "called toTupleRef on null intrusive_ptr IValue" ); |
2066 | return *static_cast<c10::ivalue::Tuple*>( |
2067 | payload.u.as_intrusive_ptr); |
2068 | } |
2069 | |
2070 | inline IValue::IValue(c10::intrusive_ptr<ivalue::Tuple> v) |
2071 | : tag(Tag::Tuple) { |
2072 | payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); |
2073 | } |
2074 | template < |
2075 | typename... Args, |
2076 | std::enable_if_t< |
2077 | !guts::disjunction< |
2078 | std::is_lvalue_reference<Args>..., |
2079 | guts::negation<std::is_constructible<IValue, Args>>...>::value, |
2080 | std::nullptr_t>> |
2081 | inline IValue::IValue(const std::tuple<Args...>& t) |
2082 | : IValue( |
2083 | std::move(c10::guts::apply(c10::ivalue::Tuple::create<const Args&...>, t))) { |
2084 | } |
2085 | |
2086 | template < |
2087 | typename... Args, |
2088 | std::enable_if_t< |
2089 | !guts::disjunction< |
2090 | std::is_lvalue_reference<Args>..., |
2091 | guts::negation<std::is_constructible<IValue, Args>>...>::value, |
2092 | std::nullptr_t>> |
2093 | inline IValue::IValue(std::tuple<Args...>&& t) |
2094 | : IValue( |
2095 | std::move(c10::guts::apply(c10::ivalue::Tuple::create<Args&&...>, std::move(t)))) { |
2096 | } |
2097 | |
2098 | inline IValue::IValue(c10::intrusive_ptr<ivalue::ConstantString> v) |
2099 | : tag(Tag::String) { |
2100 | payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); |
2101 | } |
2102 | inline IValue::IValue(std::string v) |
2103 | : IValue(ivalue::ConstantString::create(std::move(v))) {} |
2104 | |
2105 | inline IValue::IValue(c10::impl::GenericList v) |
2106 | : tag(Tag::GenericList) { |
2107 | payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release()); |
2108 | } |
2109 | |
2110 | template <class T, IValue::enable_if_list_is_ivalue_constructible<T>> |
2111 | inline IValue::IValue(c10::List<T>&& v) : IValue(impl::toList<T>(std::move(v))) {} |
2112 | template <class T, IValue::enable_if_list_is_ivalue_constructible<T>> |
2113 | inline IValue::IValue(const c10::List<T>& v) : IValue(impl::toList<T>(v)) {} |
2114 | template <class T, IValue::enable_if_list_is_ivalue_constructible<T>> |
2115 | inline IValue::IValue(at::ArrayRef<T> v) : IValue(c10::List<T>()) { |
2116 | auto list = to<c10::List<T>>(); |
2117 | list.reserve(v.size()); |
2118 | for (const auto& e : v) { |
2119 | list.push_back(e); |
2120 | } |
2121 | } |
2122 | template <class T, IValue::enable_if_symint<T>> |
2123 | inline IValue::IValue(at::ArrayRef<T> v) : IValue() { |
2124 | auto vi = c10::asIntArrayRefSlowOpt(v); |
2125 | if (vi.has_value()) { |
2126 | // This list is entirely integers; ensure it is typed as |
2127 | // an IntList so toIntList works |
2128 | *this = IValue(*vi); |
2129 | } else { |
2130 | // This list has SymInts; type it as a SymInt |
2131 | *this = IValue(impl::toList<c10::SymInt>(c10::List<c10::SymInt>())); |
2132 | auto list = to<c10::List<c10::SymInt>>(); |
2133 | list.reserve(v.size()); |
2134 | for (const auto& e : v) { |
2135 | list.push_back(e); |
2136 | } |
2137 | } |
2138 | } |
2139 | template <class T, IValue::enable_if_symint<T>> |
2140 | inline IValue::IValue(at::OptionalArrayRef<T> mb_v) : IValue() { |
2141 | if (!mb_v.has_value()) return; |
2142 | *this = IValue(*mb_v); |
2143 | } |
2144 | template <class T, IValue::enable_if_symint<T>> |
2145 | inline IValue::IValue(const std::vector<T>& v) : IValue() { |
2146 | *this = IValue(at::ArrayRef<T>(v)); |
2147 | } |
2148 | template <class T, IValue::enable_if_list_is_ivalue_constructible<T>> |
2149 | inline IValue::IValue(const std::vector<T>& v) : IValue(c10::List<T>()) { |
2150 | auto list = to<c10::List<T>>(); |
2151 | list.reserve(v.size()); |
2152 | for (const auto& e : v) { |
2153 | list.push_back(e); |
2154 | } |
2155 | } |
2156 | template <class T, IValue::enable_if_list_is_ivalue_constructible<T>> |
2157 | inline IValue::IValue(c10::OptionalArrayRef<T> v) : IValue() { |
2158 | if (v.has_value()) { |
2159 | *this = IValue(std::move(*v)); |
2160 | } |
2161 | } |
2162 | |
2163 | template <class T, size_t N> |
2164 | inline IValue::IValue(std::array<T, N> v) : IValue(c10::List<T>()) { |
2165 | auto list = to<c10::List<T>>(); |
2166 | list.reserve(v.size()); |
2167 | for (auto& e : v) { |
2168 | list.push_back(std::move(e)); |
2169 | } |
2170 | } |
2171 | |
2172 | template <class T, IValue::enable_if_ilist_is_ivalue_constructible<T>> |
2173 | inline IValue::IValue(c10::IListRef<T> v) : IValue() { |
2174 | constexpr bool boxed_type_constructs_ivalue = |
2175 | std::is_constructible<IValue, typename c10::IListRef<T>::boxed_type>::value; |
2176 | // First, we try to use the boxed value. |
2177 | // If we fail (either it's not in the boxed state, or its boxed type |
2178 | // can not construct an IValue), we fallback to copying the list. |
2179 | if (boxed_type_constructs_ivalue && v.isBoxed()) { |
2180 | *this = IValue(impl::toList(v.toBoxed())); |
2181 | } else { |
2182 | c10::List<T> list; |
2183 | list.reserve(v.size()); |
2184 | for (const auto& t : v) { |
2185 | list.push_back(t); |
2186 | } |
2187 | *this = IValue(impl::toList(std::move(list))); |
2188 | } |
2189 | } |
2190 | |
2191 | inline IValue::IValue(c10::impl::GenericDict v) |
2192 | : tag(Tag::GenericDict) { |
2193 | payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release()); |
2194 | } |
2195 | template <class Key, class Value> |
2196 | inline IValue::IValue(c10::Dict<Key, Value> v) |
2197 | : IValue(impl::toGenericDict(std::move(v))) {} |
2198 | |
2199 | template <class Key, class Value> |
2200 | inline IValue::IValue(std::unordered_map<Key, Value> v) |
2201 | : IValue(Dict<Key, Value>()) { |
2202 | auto dict = to<c10::Dict<Key, Value>>(); |
2203 | dict.reserve(v.size()); |
2204 | for (auto& e : v) { |
2205 | dict.insert(std::move(e.first), std::move(e.second)); |
2206 | } |
2207 | } |
2208 | |
2209 | template <class T, IValue::enable_if_ivalue_constructible<T>> |
2210 | inline IValue::IValue(c10::optional<T> v) : IValue() { |
2211 | if (v.has_value()) { |
2212 | *this = IValue(std::move(*v)); |
2213 | } |
2214 | } |
2215 | |
2216 | inline IValue::IValue(c10::nullopt_t) : IValue() {} |
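
// A minimal sketch of boxing standard containers and optionals with the
// constructors above:
//
//   IValue a(std::vector<int64_t>{1, 2, 3});                      // IntList
//   IValue b(std::unordered_map<std::string, int64_t>{{"x", 1}}); // Dict
//   IValue c(c10::optional<double>(3.14));                        // boxes 3.14
//   IValue d(c10::nullopt);                                       // None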
2217 | |
2218 | inline IValue::IValue(c10::intrusive_ptr<ivalue::Object> v) |
2219 | : tag(Tag::Object) { |
2220 | payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); |
2221 | } |
2222 | |
2223 | inline IValue::IValue(c10::intrusive_ptr<ivalue::PyObjectHolder> v) |
2224 | : tag(Tag::PyObject) { |
2225 | payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); |
2226 | } |
2227 | |
2228 | inline IValue::IValue(c10::intrusive_ptr<ivalue::EnumHolder> v) |
2229 | : tag(Tag::Enum) { |
2230 | payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); |
2231 | } |
2232 | |
2233 | inline IValue IValue::make_capsule( |
2234 | intrusive_ptr<torch::CustomClassHolder> blob) { |
2235 | IValue iv; |
2236 | iv.tag = Tag::Capsule; |
2237 | iv.payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release()); |
2238 | return iv; |
2239 | } |
2240 | |
2241 | template < |
2242 | typename T, |
2243 | std::enable_if_t<std::is_base_of<torch::CustomClassHolder, T>::value, int>> |
2244 | IValue::IValue(c10::intrusive_ptr<T> custom_class) : tag(Tag::Object) { |
2245 | auto classType = []() { |
2246 | try { |
2247 | return c10::getCustomClassType<c10::intrusive_ptr<T>>(); |
2248 | } catch (const c10::Error&) { |
2249 | throw c10::Error( |
2250 | "Trying to instantiate a class that isn't a registered custom class: " + |
2251 | std::string(c10::util::get_fully_qualified_type_name<T>()), |
2252 | "" ); |
2253 | } |
2254 | }(); |
2255 | auto ivalue_obj = c10::ivalue::Object::create(std::move(classType), /* numSlots */1); |
2256 | ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class))); |
  payload.u.as_intrusive_ptr = null_to_undefined_tensor(ivalue_obj.release());
}
2260 | |
2261 | inline IValue::IValue(c10::intrusive_ptr<ivalue::Future> v) |
2262 | : tag(Tag::Future) { |
2263 | payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); |
2264 | } |
2265 | |
2266 | inline IValue::IValue(c10::intrusive_ptr<ivalue::Await> v) |
2267 | : tag(Tag::Await) { |
2268 | payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); |
2269 | } |
2270 | |
2271 | inline IValue::IValue(c10::intrusive_ptr<c10::RRefInterface> v) |
2272 | : tag(Tag::RRef) { |
2273 | payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); |
2274 | } |
2275 | |
2276 | inline IValue::IValue(c10::intrusive_ptr<at::Quantizer> v) |
2277 | : tag(Tag::Quantizer) { |
2278 | payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); |
2279 | } |
2280 | |
2281 | template <typename T> |
2282 | inline IValue::IValue(c10::complex<T> c) |
2283 | : tag(Tag::ComplexDouble) { |
2284 | auto v = c10::make_intrusive<ivalue::ComplexHolder>(c); |
2285 | payload.u.as_intrusive_ptr = v.release(); |
2286 | } |
2287 | |
2288 | inline const std::string& IValue::toStringRef() const { |
2289 | AT_ASSERT(isString(), "Expected String but got " , tagKind()); |
2290 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
2291 | payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), |
2292 | "called toStringRef on null intrusive_ptr IValue" ); |
2293 | return static_cast<const c10::ivalue::ConstantString*>( |
2294 | payload.u.as_intrusive_ptr) |
2295 | ->string(); |
2296 | } |
2297 | inline c10::optional<std::reference_wrapper<const std::string>> IValue:: |
2298 | toOptionalStringRef() const { |
2299 | if (isNone()) { |
2300 | return c10::nullopt; |
2301 | } |
2302 | AT_ASSERT(isString(), "Expected optional<string> but got " , tagKind()); |
2303 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
2304 | payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), |
2305 | "called toOptionalStringRef on null intrusive_ptr IValue" ); |
2306 | return std::reference_wrapper<const std::string>( |
2307 | static_cast<const c10::ivalue::ConstantString*>(payload.u.as_intrusive_ptr) |
2308 | ->string()); |
2309 | } |
2310 | |
2311 | inline c10::string_view IValue::toStringView() const { |
2312 | AT_ASSERT(isString(), "Expected String but got " , tagKind()); |
2313 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
2314 | payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), |
2315 | "called toStringView on null intrusive_ptr IValue" ); |
2316 | return static_cast<const c10::ivalue::ConstantString*>( |
2317 | payload.u.as_intrusive_ptr) |
2318 | ->string_view(); |
2319 | } |
2320 | |
2321 | inline PyObject* IValue::toPyObject() const { |
2322 | return toPyObjectHolder()->getPyObject(); |
2323 | } |
2324 | |
2325 | template <typename T> |
2326 | inline optional<T> IValue::toOptional() { |
2327 | if (this->isNone()) { |
2328 | return nullopt; |
2329 | } |
2330 | return this->to<T>(); |
2331 | } |
2332 | |
2333 | template <typename T> |
2334 | inline optional<T> IValue::toOptional() const { |
2335 | if (this->isNone()) { |
2336 | return nullopt; |
2337 | } |
2338 | return this->to<T>(); |
2339 | } |
2340 | |
2341 | inline bool IValue::isCustomClass() const { |
2342 | return torch::isCustomClass(*this); |
2343 | } |
2344 | |
2345 | inline bool IValue::isSameIdentity(const IValue& rhs) const { |
  // We choose not to use memcmp for the payload check because the union type
  // may contain random padding bytes.

  // Semantics:
  // 1. Immutable primitive values of the same type (Int, Double, None, Bool,
  // Str) use value equality.
  // 2. If it is a tensor type, we need to take undefined tensors into account.
  // 3. An undefined tensor and None should compare as the same identity.
  // 4. If it is a reference type (i.e. isIntrusivePtr()), then it is True when
  // the pointed-to object is the same.
  // 5. False for all other comparisons.
2357 | if (this->isNone() && rhs.isNone()) { |
2358 | return true; |
2359 | } else if (this->isBool() && rhs.isBool()) { |
2360 | // for bool type, do equality check |
2361 | return this->toBool() == rhs.toBool(); |
2362 | } else if (this->isTensor() && rhs.isTensor()) { |
2363 | return this->payload.as_tensor.is_same(rhs.payload.as_tensor); |
2364 | } else if (this->isTensor() && rhs.isNone()) { |
2365 | // special case: undefined tensor and None are the same identity |
2366 | return !this->payload.as_tensor.defined(); |
2367 | } else if (this->isNone() && rhs.isTensor()) { |
2368 | // special case: undefined tensor and None are the same identity |
2369 | return !rhs.payload.as_tensor.defined(); |
2370 | } else if (this->isInt() && rhs.isInt()) { |
2371 | return this->toInt() == rhs.toInt(); |
2372 | } else if (this->isDouble() && rhs.isDouble()) { |
2373 | return this->toDouble() == rhs.toDouble(); |
2374 | } else if (this->isString() && rhs.isString()) { |
2375 | return this->toStringRef() == rhs.toStringRef(); |
2376 | } else { |
    // For objects held in an IValue, do a shallow compare on the pointer
    // address to test identity.
2379 | return this->isIntrusivePtr() && rhs.isIntrusivePtr() && |
2380 | this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr; |
2381 | } |
2382 | } |
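
// A minimal sketch of the identity semantics above:
//
//   IValue none1, none2;
//   TORCH_INTERNAL_ASSERT(none1.isSameIdentity(none2));  // both None
//   IValue s1(std::string("a")), s2(std::string("a"));
//   TORCH_INTERNAL_ASSERT(s1.isSameIdentity(s2));        // string value equality
//   IValue t(at::Tensor{});                              // undefined tensor
//   TORCH_INTERNAL_ASSERT(t.isSameIdentity(none1));      // undefined == None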
2383 | |
2384 | namespace ivalue { |
2385 | namespace detail { |
2386 | |
2387 | template <typename T> |
2388 | IValue from_(T&& x, std::true_type) { |
2389 | return IValue(std::forward<T>(x)); |
2390 | } |
2391 | template <typename T> |
2392 | IValue from_(c10::intrusive_ptr<T> x, std::false_type) { |
2393 | return IValue(std::move(x)); |
2394 | } |
2395 | template <typename T> |
2396 | IValue from_(T&& /*x*/, std::false_type) { |
2397 | static_assert( |
2398 | guts::false_t<T>::value, |
2399 | "You are calling from with a type that it doesn't support, and isn't a potential custom class (ie: is an intrusive_ptr)" ); |
2400 | return IValue(); |
2401 | } |
2402 | } // namespace detail |
2403 | |
2404 | template <typename T> |
2405 | IValue from(T&& x) { |
2406 | return detail::from_( |
2407 | std::forward<T>(x), typename std::is_constructible<IValue, T>::type{}); |
2408 | } |
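
// For example, both of these compile because IValue is constructible from
// the argument; a non-constructible type would hit the static_assert above:
//
//   IValue a = from(static_cast<int64_t>(5));
//   IValue b = from(std::string("hi"));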
2409 | |
2410 | } // namespace ivalue |
2411 | |
2412 | |
2413 | template <> |
2414 | struct MaybeOwnedTraits<IValue> { |
2415 | using owned_type = IValue; |
2416 | using borrow_type = IValue; |
2417 | |
2418 | static borrow_type createBorrow(const owned_type& from) { |
2419 | if (!from.isPtrType()) { |
2420 | return from; |
2421 | } |
2422 | if (from.isTensor()) { |
2423 | return IValue(MaybeOwnedTraits<at::Tensor>::createBorrow(from.toTensor())); |
2424 | } else { |
2425 | return IValue(from.payload, from.tag); |
2426 | } |
2427 | } |
2428 | |
2429 | static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) { |
2430 | lhs.clearToNone(); |
2431 | if (!rhs.isPtrType()) { |
2432 | lhs = rhs; |
2433 | } else if (rhs.isTensor()) { |
2434 | lhs = IValue(MaybeOwnedTraits<at::Tensor>::createBorrow(rhs.toTensor())); |
2435 | } else { |
2436 | lhs = IValue(rhs.payload, rhs.tag); |
2437 | } |
2438 | } |
2439 | |
2440 | static void destroyBorrow(borrow_type& toDestroy) { |
2441 | toDestroy.clearToNone(); |
2442 | } |
2443 | |
2444 | static const owned_type& referenceFromBorrow(const borrow_type& borrow) { |
2445 | return borrow; |
2446 | } |
2447 | |
2448 | static const owned_type* pointerFromBorrow(const borrow_type& borrow) { |
2449 | return &borrow; |
2450 | } |
2451 | |
2452 | static bool debugBorrowIsValid(const borrow_type&) { |
2453 | return true; |
2454 | } |
2455 | }; |
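
// A minimal sketch of borrowing an IValue via c10::MaybeOwned without
// touching refcounts:
//
//   IValue owned(std::string("payload"));
//   auto borrowed = c10::MaybeOwned<IValue>::borrowed(owned);
//   const IValue& ref = *borrowed;  // valid only while `owned` is alive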
2456 | |
2457 | template <> |
2458 | struct IValue::TagType<c10::Type> { |
2459 | static TORCH_API c10::TypePtr get(const IValue&); |
2460 | }; |
2461 | |
2462 | template <> |
2463 | struct IValue::TagType<c10::DynamicType> { |
2464 | static TORCH_API c10::TypePtr get(const IValue&); |
2465 | }; |
2466 | |
2467 | template <typename T> |
2468 | TypePtr IValue::type() const { |
2469 | return IValue::TagType<T>::get(*this); |
2470 | } |
2471 | |
2472 | } // namespace c10 |
2473 | |