1 | #pragma once |
2 | |
#include <ATen/core/custom_class.h>
#include <ATen/core/jit_type_base.h>
#include <ATen/core/TensorBody.h>
#include <ATen/core/functional.h>
#include <ATen/core/symbol.h>
#include <ATen/core/type_factory.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/TypeList.h>
#include <c10/util/Optional.h>
#include <c10/core/SymFloat.h>

#include <array>
#include <atomic>
#include <memory>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
20 | |
// Forward declaration of torch::jit::Function so this header can name the
// type without including its definition.
// NOTE(review): no use is visible in this chunk; presumably it is referenced
// further down the header.
namespace torch {
namespace jit {
struct Function;
} // namespace jit
} // namespace torch
26 | |
27 | namespace c10 { |
28 | |
// Forward declarations for types that are only referred to by name here.
template<class Key, class Value>
class Dict;
struct IValue;
struct FunctionSchema;
struct NamedType;
// An optional list of names.
// NOTE(review): presumably nullopt means "no names given" — confirm at use sites.
using OptNameList = c10::optional<std::vector<std::string>>;

// Declared here, defined out of line. Going by the names, they normalize type
// lists for UnionType construction (`to_flatten` suggests flattening nested
// unions) — NOTE(review): confirm against the definitions.
void standardizeVectorForUnion(std::vector<TypePtr>& reference, std::vector<TypePtr>* to_fill);
void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten);
38 | |
39 | inline bool is_contiguous_strides( |
40 | const IntArrayRef sizes, |
41 | const IntArrayRef strides) { |
42 | int n_dim = static_cast<int>(sizes.size()); |
43 | if (n_dim == 0) { |
44 | return true; |
45 | } |
46 | |
47 | if (strides[n_dim - 1] != 1) { |
48 | return false; |
49 | } |
50 | |
51 | for (int i = n_dim - 2; i >= 0; i--) { |
52 | if (strides[i] != strides[i + 1] * sizes[i + 1]) { |
53 | return false; |
54 | } |
55 | } |
56 | return true; |
57 | } |
58 | |
59 | struct AnyType; |
60 | using AnyTypePtr = SingletonTypePtr<AnyType>; |
61 | // Any is the top of the type hierarchy, all other types are subtypes |
62 | // T <: Any, forall T |
63 | struct TORCH_API AnyType : public Type { |
64 | bool equals(const Type& rhs) const override { |
65 | return rhs.kind() == kind(); |
66 | } |
67 | std::string str() const override { |
68 | return "Any" ; |
69 | } |
70 | static const TypeKind Kind = TypeKind::AnyType; |
71 | // global singleton |
72 | static AnyTypePtr get(); |
73 | |
74 | private: |
75 | AnyType() : Type(TypeKind::AnyType) {} |
76 | }; |
77 | |
78 | inline std::string toString(const Type& type) { |
79 | return type.str(); |
80 | } |
81 | |
82 | // Shim for compatibility with code that uses TypePtr. |
83 | inline std::string toString(const TypePtr& typePtr) { |
84 | return toString(*typePtr); |
85 | } |
86 | |
87 | inline bool operator!=(const Type& lhs, const Type& rhs) { |
88 | return !(lhs == rhs); |
89 | } |
90 | |
91 | // common base for all types that have a single sub element |
92 | // e.g. Future[T], Optional[T], List[T] |
93 | template <TypeKind K, typename T> |
94 | struct SingleElementType : public SharedType { |
95 | static const TypeKind Kind = K; |
96 | |
97 | const TypePtr& getElementType() const { |
98 | return elem; |
99 | } |
100 | |
101 | bool hasFreeVariables() const override { |
102 | return getElementType()->hasFreeVariables(); |
103 | } |
104 | |
105 | at::ArrayRef<TypePtr> containedTypes() const override { |
106 | return elem; |
107 | } |
108 | |
109 | bool equals(const Type& rhs) const override { |
110 | if (auto rhs_ = rhs.cast<T>()) { |
111 | return *getElementType() == *rhs_->getElementType(); |
112 | } |
113 | return false; |
114 | } |
115 | |
116 | protected: |
117 | SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) { |
118 | if (!this->elem) { |
119 | throw std::runtime_error(c10::str( |
120 | "Can not create " , typeKindToString(Kind), " with None type" )); |
121 | } |
122 | } |
123 | |
124 | private: |
125 | TypePtr elem; |
126 | }; |
127 | |
struct UnionType;
using UnionTypePtr = std::shared_ptr<UnionType>;
// Union type over the member types stored in `types_`.
// Most behavior (subtyping, printing, construction) is defined out of line;
// only trivial accessors are inline here.
struct TORCH_API UnionType : public SharedType {
  friend struct Type;

  static const TypeKind Kind = TypeKind::UnionType;

  bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override;

  std::string str() const override;

  // Factory for union types.
  // NOTE(review): any normalization of `reference` happens in the out-of-line
  // definition — confirm there.
  static UnionTypePtr create(std::vector<TypePtr> reference);

  bool equals(const Type& rhs) const override;

  bool isUnionType() const override {
    return true;
  }

  // The union's member types (a view over types_).
  at::ArrayRef<TypePtr> containedTypes() const override {
    return types_;
  }

  // For testing purposes only
  at::ArrayRef<TypePtr> getTypes() const {
    return types_;
  }

  // Rebuilds a union from the given member types via create().
  TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
    return create(std::move(contained_types));
  }

  bool canHoldType(const Type& type) const;

  // Returns the cached has_free_variables_ flag.
  bool hasFreeVariables() const override {
    return has_free_variables_;
  }

  c10::optional<TypePtr> toOptional() const;

  c10::optional<TypePtr> subtractTypeSet(std::vector<TypePtr>& to_subtract) const;

 protected:
  // `kind` lets subclasses (OptionalType) construct with their own TypeKind.
  explicit UnionType(std::vector<TypePtr> types, TypeKind kind=TypeKind::UnionType);
  std::string annotation_str_impl(TypePrinter printer = nullptr) const override;
  std::string unionStr(
      TypePrinter printer = nullptr,
      bool is_annotation_str = false) const;
  // Cached result reported by hasFreeVariables(); set out of line
  // (presumably by the constructor — confirm).
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool has_free_variables_;
  // Member types, exposed through containedTypes() and getTypes().
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<TypePtr> types_;
  // NOTE(review): not read inline; presumably whether None is a member — confirm.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool can_hold_none_;

};
184 | |
185 | struct OptionalType; |
186 | using OptionalTypePtr = std::shared_ptr<OptionalType>; |
187 | // This type represents an optional type. There is one `Optional` for |
188 | // each element type. `Optional[T]` can accept both `T` and |
189 | // `None`(`c10::nullopt` in C++) |
190 | // Subtype hierarchy for Optional: |
191 | // - Optional[T] <: Optional[R] iff T <: R |
192 | // - T <: Optional[R] if T <: R |
193 | // - None <: Optional[T] for all T |
194 | // - Optional[T] == Union[T, None] for all T |
195 | struct TORCH_API OptionalType : public UnionType { |
196 | static OptionalTypePtr create(TypePtr contained); |
197 | |
198 | static const TypeKind Kind = TypeKind::OptionalType; |
199 | |
200 | friend struct Type; |
201 | |
202 | bool equals(const Type& rhs) const override; |
203 | |
204 | const TypePtr& getElementType() const { |
205 | return contained_; |
206 | } |
207 | |
208 | at::ArrayRef<TypePtr> containedTypes() const override { |
209 | return contained_; |
210 | } |
211 | |
212 | std::string str() const override { |
213 | std::stringstream ss; |
214 | ss << getElementType()->str() << "?" ; |
215 | return ss.str(); |
216 | } |
217 | |
218 | TypePtr createWithContained( |
219 | std::vector<TypePtr> contained_types) const override { |
220 | AT_ASSERT(contained_types.size() == 1); |
221 | return create(std::move(contained_types[0])); |
222 | } |
223 | |
224 | bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; |
225 | |
226 | bool isUnionType() const override { |
227 | return true; |
228 | } |
229 | |
230 | // common cast Optional[Tensor] for undefined tensor type |
231 | static TypePtr ofTensor(); |
232 | // |
233 | // global singleton |
234 | static TypePtr get(TypePtr inner); |
235 | |
236 | private: |
237 | explicit OptionalType(TypePtr contained); |
238 | |
239 | TypePtr contained_; |
240 | |
241 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override { |
242 | std::stringstream ss; |
243 | ss << "Optional[" << getElementType()->annotation_str(std::move(printer)) << "]" ; |
244 | return ss.str(); |
245 | } |
246 | }; |
247 | |
248 | template <typename T> |
249 | inline c10::optional<T> merge_primitive( |
250 | const c10::optional<T>& a, |
251 | const c10::optional<T>& b) { |
252 | if (a.has_value() && b.has_value() && a.value() == b.value()) { |
253 | return a; |
254 | } |
255 | return c10::optional<T>{}; |
256 | } |
257 | |
// If we see `a + b + c` and know that a, b, and c are the same size and have
// two dimensions (WxH), then we can generate a fused kernel for them. That
// fused kernel would likely have indexing math to handle both the W and H
// dimensions. However, if we knew the WxH dimensions were contiguous, we could
// pretend we only have a single dimension, simplifying the indexing logic.
// This can be performed even if the dimensions are transposed,
// as long as a, b, and c are transposed in the same way.
// We'd like the compiler to be able to do this dimensionality reduction,
// but simply knowing sizes is not enough.
// We can extend profiling to also record stride information.
// Rather than recording specific strides,
// we can simply order the strides from smallest to largest with
// `stride_indices`. A contiguity marker on the smallest stride (c0) indicates
// the stride is precisely 1; otherwise a contiguity marker means that
// $stride_n = size_{n-1} \cdot stride_{n-1}$.
273 | struct TORCH_API Stride { |
274 | Stride() = default; |
275 | Stride( |
276 | const c10::optional<size_t>& stride_index, |
277 | c10::optional<bool> contiguous, |
278 | const c10::optional<size_t>& stride) |
279 | : stride_index_(stride_index), contiguous_(contiguous), stride_(stride) {} |
280 | |
281 | bool operator==(const Stride& b) const { |
282 | return stride_index_ == b.stride_index_ && contiguous_ == b.contiguous_ && |
283 | stride_ == b.stride_; |
284 | } |
285 | |
286 | bool isComplete() const { |
287 | return stride_index_ && contiguous_ && stride_; |
288 | } |
289 | |
290 | c10::optional<size_t> stride_index_; |
291 | c10::optional<bool> contiguous_; |
292 | c10::optional<size_t> stride_; |
293 | }; |
294 | |
295 | template <> |
296 | inline c10::optional<Stride> merge_primitive( |
297 | const c10::optional<Stride>& a, |
298 | const c10::optional<Stride>& b) { |
299 | c10::optional<Stride> left = a; |
300 | c10::optional<Stride> right = b; |
301 | if (!left.has_value()) { |
302 | left = {Stride()}; |
303 | } |
304 | if (!right.has_value()) { |
305 | right = {Stride()}; |
306 | } |
307 | |
308 | auto merged_index = |
309 | merge_primitive(left->stride_index_, right->stride_index_); |
310 | auto merged_cont = merge_primitive(left->contiguous_, right->contiguous_); |
311 | auto merged_stride = merge_primitive(left->stride_, right->stride_); |
312 | auto r = Stride(merged_index, merged_cont, merged_stride); |
313 | // normalize |
314 | if (!r.stride_index_.has_value() && !r.contiguous_.has_value() && |
315 | !r.stride_.has_value()) { |
316 | return c10::optional<Stride>{}; |
317 | } |
318 | |
319 | return r; |
320 | } |
321 | |
322 | struct TORCH_API ShapeSymbol { |
323 | // needed for use in `std::map` |
324 | ShapeSymbol() : value_(-1) {} |
325 | // is this symbol a fixed/static dimension |
326 | bool is_static() const { |
327 | return value_ >= 0; |
328 | }; |
329 | bool operator==(const ShapeSymbol& b) const { |
330 | return value_ == b.value_; |
331 | } |
332 | bool operator<(const ShapeSymbol& b) const { |
333 | return value_ < b.value_; |
334 | } |
335 | |
336 | static ShapeSymbol fromStaticSize(int64_t val) { |
337 | return ShapeSymbol(val); |
338 | } |
339 | int64_t static_size() const { |
340 | TORCH_CHECK(is_static()); |
341 | return value_; |
342 | }; |
343 | |
344 | int64_t value() const { |
345 | return value_; |
346 | }; |
347 | |
348 | static ShapeSymbol newSymbol() { |
349 | return fromStaticSize(-static_cast<int64_t>(++num_symbols)); |
350 | }; |
351 | friend TORCH_API std::ostream& operator<<( |
352 | std::ostream& os, |
353 | const ShapeSymbol& s); |
354 | |
355 | private: |
356 | ShapeSymbol(int64_t val) : value_(val) {} |
357 | int64_t value_; |
358 | static std::atomic<size_t> num_symbols; |
359 | }; |
360 | |
361 | inline ShapeSymbol merge_primitive( |
362 | const ShapeSymbol& a, |
363 | const ShapeSymbol& b) { |
364 | if (a.is_static() && b.is_static() && a == b) { |
365 | return a; |
366 | } |
367 | return ShapeSymbol::newSymbol(); |
368 | } |
369 | |
370 | // Shape of a Tensor represented with ShapeSymbol's. Unranked, ranked unknown |
371 | // dims, partially known and fully known shapes are all supported. |
372 | struct TORCH_API SymbolicShape { |
373 | // Unranked shape constructor. |
374 | SymbolicShape() : dims_(c10::nullopt) {} |
375 | |
376 | // Known rank but unknown dimentions. |
377 | SymbolicShape(c10::optional<size_t> rank) : dims_(c10::nullopt) { |
378 | if(!rank) { |
379 | return; |
380 | } |
381 | |
382 | std::vector<ShapeSymbol> shape_symbols; |
383 | shape_symbols.reserve(*rank); |
384 | for(size_t i = 0; i < *rank; ++i) { |
385 | shape_symbols.push_back(ShapeSymbol::newSymbol()); |
386 | } |
387 | dims_ = shape_symbols; |
388 | } |
389 | |
390 | // Mix of known and unknown ranks |
391 | SymbolicShape(const std::vector<c10::optional<int64_t>>& dims) { |
392 | std::vector<ShapeSymbol> shape_symbols; |
393 | shape_symbols.reserve(dims.size()); |
394 | for(c10::optional<int64_t> dim: dims) { |
395 | if(!dim) { |
396 | shape_symbols.push_back(ShapeSymbol::newSymbol()); |
397 | } else { |
398 | shape_symbols.push_back(ShapeSymbol::fromStaticSize(*dim)); |
399 | } |
400 | } |
401 | dims_ = shape_symbols; |
402 | } |
403 | |
404 | void dump() const; |
405 | |
406 | SymbolicShape(std::vector<ShapeSymbol> dims) : dims_(std::move(dims)) {} |
407 | |
408 | SymbolicShape(c10::IntArrayRef dims) { |
409 | std::vector<ShapeSymbol> shape_symbols; |
410 | shape_symbols.reserve(dims.size()); |
411 | for(int64_t dim : dims) { |
412 | shape_symbols.push_back(ShapeSymbol::fromStaticSize(dim)); |
413 | } |
414 | dims_ = shape_symbols; |
415 | } |
416 | |
417 | ShapeSymbol operator[](size_t i) const { |
418 | if (!dims_) { |
419 | throw std::runtime_error("Rank isn't fixed" ); |
420 | } |
421 | return (*dims_).at(i); |
422 | } |
423 | |
424 | ShapeSymbol at(size_t i) const { |
425 | if (!dims_) { |
426 | throw std::runtime_error("Rank isn't fixed" ); |
427 | } |
428 | return (*dims_).at(i); |
429 | } |
430 | |
431 | // Returns rank or nullopt in case of unranked shape. |
432 | c10::optional<size_t> rank() const { |
433 | if(!dims_) { |
434 | return c10::nullopt; |
435 | } |
436 | return dims_->size(); |
437 | } |
438 | |
439 | c10::optional<std::vector<ShapeSymbol>> sizes() const { |
440 | return dims_; |
441 | } |
442 | |
443 | c10::optional<std::vector<bool>> symbolicDims() const { |
444 | if (!dims_) { |
445 | return c10::nullopt; |
446 | } |
447 | auto symbolic_dims = std::vector<bool>(); |
448 | for (const ShapeSymbol& s : *dims_) { |
449 | symbolic_dims.push_back(!s.is_static()); |
450 | } |
451 | return symbolic_dims; |
452 | } |
453 | |
454 | // Checks whether the shape is fully defined/complete, ie. rank and sizes |
455 | // of every dimension are known. |
456 | bool isComplete() const { |
457 | if(!dims_) { |
458 | return false; |
459 | } |
460 | for(auto d : *dims_) { |
461 | if(!d.is_static()) { |
462 | return false; |
463 | } |
464 | } |
465 | return true; |
466 | } |
467 | |
468 | // Create new SymbolicShape that is result of merging self and another |
469 | // SymbolicShape. Only dimensions that are static and equal will be |
470 | // preserved. |
471 | // If either of two shapes are of unknown rank or they have unmatching rank, |
472 | // result will be unranked. |
473 | SymbolicShape merge(const SymbolicShape& other) const; |
474 | |
475 | friend bool operator==(const SymbolicShape& lhs, const SymbolicShape& rhs) { |
476 | return lhs.dims_ == rhs.dims_; |
477 | } |
478 | |
479 | friend bool operator!=(const SymbolicShape& lhs, const SymbolicShape& rhs) { |
480 | return !(lhs == rhs); |
481 | } |
482 | |
483 | private: |
484 | c10::optional<std::vector<ShapeSymbol>> dims_; |
485 | }; |
486 | |
487 | namespace detail { |
488 | inline bool isComplete(const Stride& s) { |
489 | return s.isComplete(); |
490 | } |
491 | |
492 | template<typename T> |
493 | inline bool isComplete(const T& /*t*/) { |
494 | return true; |
495 | } |
496 | } |
497 | |
498 | template <typename T> |
499 | struct VaryingShape { |
500 | using ListOfOptionalElements = std::vector<c10::optional<T>>; |
501 | VaryingShape(const std::vector<T>& vec) |
502 | : VaryingShape(ListOfOptionalElements(vec.begin(), vec.end())) {} |
503 | |
504 | VaryingShape(c10::ArrayRef<T> vec) |
505 | : VaryingShape(ListOfOptionalElements(vec.begin(), vec.end())) {} |
506 | |
507 | VaryingShape(c10::optional<size_t> size = c10::nullopt) : dims_(c10::nullopt) { |
508 | if (size) { |
509 | dims_ = ListOfOptionalElements(*size); |
510 | } |
511 | } |
512 | |
513 | VaryingShape(ListOfOptionalElements dims) : dims_(std::move(dims)) {} |
514 | |
515 | VaryingShape(size_t size) : VaryingShape(c10::optional<size_t>(size)) {} |
516 | |
517 | bool operator==(const VaryingShape& other) const { |
518 | return dims_ == other.dims_; |
519 | } |
520 | |
521 | const c10::optional<T> &operator[](size_t i) const { |
522 | if (!dims_) { |
523 | throw std::runtime_error("Rank isn't fixed" ); |
524 | } |
525 | return (*dims_).at(i); |
526 | } |
527 | |
528 | c10::optional<size_t> size() const { |
529 | if (!dims_) { |
530 | return c10::nullopt; |
531 | } |
532 | const auto& dims = dims_.value(); |
533 | return dims.size(); |
534 | } |
535 | |
536 | const c10::optional<ListOfOptionalElements>& sizes() const { |
537 | return dims_; |
538 | } |
539 | |
540 | TORCH_API VaryingShape merge(const VaryingShape& other) const; |
541 | |
542 | c10::optional<std::vector<T>> concrete_sizes() const { |
543 | if (!dims_) { |
544 | return c10::nullopt; |
545 | } |
546 | std::vector<T> sizes; |
547 | for (auto d : *dims_) { |
548 | if (!d) { |
549 | return c10::nullopt; |
550 | } |
551 | sizes.push_back(d.value()); |
552 | } |
553 | return sizes; |
554 | } |
555 | |
556 | bool isComplete() const { |
557 | if (!dims_) { |
558 | return false; |
559 | } |
560 | for (auto d : *dims_) { |
561 | if (!d || !detail::isComplete(*d)) { |
562 | return false; |
563 | } |
564 | } |
565 | return true; |
566 | } |
567 | |
568 | private: |
569 | c10::optional<ListOfOptionalElements> dims_; |
570 | }; |
571 | |
572 | struct TensorType; |
573 | // TODO: investigate making this SingletonOrSharedTypePtr<TensorType> |
574 | using TensorTypePtr = std::shared_ptr<TensorType>; |
575 | // This type represents a single Tensor with a specific size |
576 | struct TORCH_API TensorType : public SharedType { |
577 | static TensorTypePtr create(const at::Tensor& t); |
578 | |
579 | // used by TensorType::create(size_t dim) which in turn used by |
580 | // shape_analysis.cpp |
581 | static TensorTypePtr create( |
582 | c10::optional<at::ScalarType> scalar_type, |
583 | c10::optional<Device> device, |
584 | const VaryingShape<int64_t>& sizes, |
585 | const VaryingShape<int64_t>& strides, |
586 | c10::optional<bool> requires_grad, |
587 | c10::optional<bool> undefined = false, |
588 | bool tensor_contiguity = false); |
589 | |
590 | static TensorTypePtr create( |
591 | c10::optional<at::ScalarType> scalar_type, |
592 | c10::optional<Device> device, |
593 | const SymbolicShape& sizes, |
594 | const VaryingShape<Stride>& stride_, |
595 | c10::optional<bool> requires_grad, |
596 | c10::optional<bool> undefined = false); |
597 | |
598 | static TensorTypePtr create( |
599 | c10::optional<at::ScalarType> scalar_type, |
600 | c10::optional<Device> device, |
601 | c10::optional<size_t> dim, |
602 | c10::optional<bool> requires_grad); |
603 | |
604 | // overloaded create variadic template argument as it could not distinguish |
605 | // initializer list |
606 | static TensorTypePtr createContiguous( |
607 | at::ScalarType scalar_type, |
608 | at::Device device, |
609 | at::IntArrayRef sizes); |
610 | |
611 | static TypePtr fromNumberType(const Type& typ); |
612 | static TypePtr fromBoolType(); |
613 | |
614 | c10::optional<size_t> dim() const { |
615 | return sizes().size(); |
616 | } |
617 | |
618 | VaryingShape<int64_t> sizes() const; |
619 | |
620 | VaryingShape<int64_t> strides() const; |
621 | |
622 | const VaryingShape<Stride>& stride_properties() const { |
623 | return strides_; |
624 | } |
625 | |
626 | c10::optional<at::Device> device() const { |
627 | return device_; |
628 | } |
629 | c10::optional<at::ScalarType> scalarType() const { |
630 | return scalar_type_; |
631 | } |
632 | c10::optional<bool> requiresGrad() const { |
633 | return requires_grad_; |
634 | } |
635 | bool requires_grad() const override { |
636 | return requires_grad_ ? *requires_grad_ : true; |
637 | } |
638 | |
639 | bool equals(const Type& rhs) const override; |
640 | bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; |
641 | |
642 | std::string str() const override; |
643 | |
644 | std::string repr_str() const override { |
645 | if (isInferredType()) { |
646 | return str() + " (inferred)" ; |
647 | } else { |
648 | return str(); |
649 | } |
650 | } |
651 | |
652 | c10::optional<size_t> numel() const { |
653 | size_t prod = 1; |
654 | const auto& shape = sizes(); |
655 | |
656 | for (size_t i = 0; i < shape.size(); i++) { |
657 | if (!shape[i]) { |
658 | return c10::optional<size_t>{}; |
659 | } |
660 | prod *= shape[i].value(); |
661 | } |
662 | return prod; |
663 | } |
664 | |
665 | TensorTypePtr withRequiresGrad(c10::optional<bool> s) { |
666 | auto copy = clone(); |
667 | copy->requires_grad_ = s; |
668 | return copy; |
669 | } |
670 | |
671 | TensorTypePtr withScalarType(c10::optional<ScalarType> st) { |
672 | auto copy = clone(); |
673 | copy->scalar_type_ = st; |
674 | return copy; |
675 | } |
676 | |
677 | TensorTypePtr withDim(c10::optional<size_t> d) { |
678 | auto copy = clone(); |
679 | // withDim is only used by the legacy executor |
680 | // that only cares about the rank, so create dummy symbols)) : |
681 | copy->sizes_ = SymbolicShape(d); |
682 | copy->strides_ = VaryingShape<Stride>(d); |
683 | return copy; |
684 | } |
685 | |
686 | TensorTypePtr withStrides(VaryingShape<Stride> sstrides) const { |
687 | auto cloned = clone(); |
688 | cloned->strides_ = sstrides; |
689 | return cloned; |
690 | } |
691 | |
692 | TensorTypePtr withSizesStrides( |
693 | at::IntArrayRef sizes, |
694 | at::IntArrayRef strides) const { |
695 | auto cloned = clone(); |
696 | auto ssizes = SymbolicShape(sizes); |
697 | cloned->sizes_ = ssizes; |
698 | cloned->strides_ = computeStrideProps(sizes, strides); |
699 | return cloned; |
700 | } |
701 | |
702 | TensorTypePtr withSymbolicShapes(SymbolicShape ssizes) const { |
703 | auto cloned = clone(); |
704 | cloned->sizes_ = std::move(ssizes); |
705 | return cloned; |
706 | } |
707 | |
708 | TensorTypePtr withSizes(at::IntArrayRef sizes) const { |
709 | return withSizesStrides( |
710 | sizes, contiguousStridesOf(sizes)); |
711 | } |
712 | |
713 | TensorTypePtr withDevice(const c10::optional<at::Device> device) const { |
714 | auto copy = clone(); |
715 | copy->device_ = device; |
716 | return copy; |
717 | } |
718 | |
719 | TensorTypePtr dimensionedOnly() const { |
720 | auto copy = clone(); |
721 | copy->sizes_ = SymbolicShape(sizes().size()); |
722 | copy->strides_ = VaryingShape<Stride>(sizes().size()); |
723 | return copy; |
724 | } |
725 | |
726 | TensorTypePtr contiguous() const { |
727 | auto cloned = clone(); |
728 | TORCH_INTERNAL_ASSERT(sizes().concrete_sizes().has_value()); |
729 | auto strides = computeStrideProps( |
730 | *sizes().concrete_sizes(), |
731 | contiguousStridesOf(*sizes().concrete_sizes())); |
732 | cloned->strides_ = strides; |
733 | return cloned; |
734 | } |
735 | |
736 | const SymbolicShape& symbolic_sizes() const; |
737 | |
738 | TensorTypePtr merge(const TensorType& other, bool merge_sizes = true) const; |
739 | |
740 | bool matchTensor(const at::Tensor& t); |
741 | |
742 | // is all information about the type specified except for autograd? |
743 | // This replaces the notion of a 'CompleteTensorType' that used to exist |
744 | // in the type-hierarchy. Excluding require_grad and undefined allows |
745 | // this to match the old behavior. |
746 | bool isComplete() const { |
747 | return scalar_type_ && device_ && sizes_.isComplete() && strides_.isComplete(); |
748 | } |
749 | |
750 | bool isInferredType() const { |
751 | return is_inferred_; |
752 | } |
753 | |
754 | static TensorTypePtr getInferred() { |
755 | static auto valueInferred = TensorType::create( |
756 | /*scalar_type=*/{}, |
757 | /*device=*/{}, |
758 | /*sizes=*/SymbolicShape(), |
759 | /*stride=*/VaryingShape<Stride>{}, |
760 | /*requires_grad=*/{}, |
761 | /*undefined=*/false); |
762 | valueInferred->is_inferred_ = true; |
763 | return valueInferred; |
764 | } |
765 | |
766 | // this property is used by GuardElimination |
767 | // please see `checkInputs` for more details |
768 | bool isSummarized() const { |
769 | return !(isComplete() && requiresGrad().has_value() && |
770 | undefined().has_value()); |
771 | } |
772 | |
773 | TensorTypePtr withUndefined() { |
774 | auto r = clone(); |
775 | r->undefined_ = true; |
776 | return r; |
777 | } |
778 | |
779 | TensorTypePtr withPossiblyUndefined() { |
780 | auto r = clone(); |
781 | r->undefined_ = c10::nullopt; |
782 | return r; |
783 | } |
784 | |
785 | c10::optional<bool> undefined() const { return undefined_; } |
786 | |
787 | static const TensorTypePtr& get(); |
788 | |
789 | static const TypeKind Kind = TypeKind::TensorType; |
790 | |
791 | static std::vector<int64_t> contiguousStridesOf( |
792 | at::IntArrayRef in_sizes, |
793 | at::MemoryFormat memory_format = MemoryFormat::Contiguous) { |
794 | auto contiguous_fn = [](const at::IntArrayRef& sizes, |
795 | const std::vector<int64_t>& dim_order) { |
796 | std::vector<int64_t> strides(sizes.size()); |
797 | if (sizes.empty()) // zero-dim case |
798 | return strides; |
799 | |
800 | strides[dim_order[0]] = 1; |
801 | for (size_t i = 1; i < dim_order.size(); i++) { |
802 | auto cur_dim = dim_order[i]; |
803 | auto pre_dim = dim_order[i - 1]; |
804 | strides[cur_dim] = strides[pre_dim] * sizes[pre_dim]; |
805 | } |
806 | return strides; |
807 | }; |
808 | |
809 | std::vector<int64_t> dim_order(in_sizes.size()); |
810 | if (memory_format == MemoryFormat::ChannelsLast) { |
811 | dim_order = {1, 3, 2, 0}; |
812 | } else if (memory_format == MemoryFormat::ChannelsLast3d) { |
813 | dim_order = {1, 4, 3, 2, 0}; |
814 | } else { |
815 | auto ndims = in_sizes.size(); |
816 | for (size_t i = 0; i < ndims; i++) { |
817 | dim_order[i] = ndims - i - 1; // Reverse |
818 | } |
819 | } |
820 | return contiguous_fn(in_sizes, dim_order); |
821 | } |
822 | |
823 | private: |
824 | TensorType( |
825 | c10::optional<at::ScalarType> scalar_type, |
826 | c10::optional<Device> device, |
827 | const SymbolicShape& sizes, |
828 | const VaryingShape<Stride>& strides, |
829 | c10::optional<bool> requires_grad, |
830 | c10::optional<bool> undefined = false); |
831 | |
832 | TensorTypePtr clone() const { |
833 | return TensorTypePtr(new TensorType( |
834 | scalar_type_, device_, sizes_, strides_, requires_grad_, undefined_)); |
835 | } |
836 | |
837 | static VaryingShape<Stride> computeStrideProps( |
838 | at::IntArrayRef sizes, |
839 | at::IntArrayRef strides, |
840 | bool tensor_contiguity = false); |
841 | |
842 | c10::optional<at::ScalarType> scalar_type_; |
843 | c10::optional<at::Device> device_; |
844 | SymbolicShape sizes_; |
845 | VaryingShape<Stride> strides_; |
846 | c10::optional<bool> requires_grad_; |
847 | // we exploit the fact certain tensors must be zero in the autograd to |
848 | // optimize gradient computation. Such zero tensors are currently implemented |
849 | // with `UndefinedTensorImpl.` They can be handled only by special operators |
850 | // (e.g. `AutogradAdd`) and their `Tensor::defined()` property returns false. |
851 | // Normally, `undefined_` is set to false, unless a type was created |
852 | // with `withUndefined` |
853 | // This will also mean that `undefined` tensors will fail |
854 | // `subtypeOf(TensorType::get())` check |
855 | // undefined_ may become `c10::nullopt` if the tensor was observed to be both |
856 | // defined and undefined. However, no tensor type starts out with |
857 | // `undefined_` set to `c10::nullopt` |
858 | c10::optional<bool> undefined_; |
859 | // Represents whether or not this type was inferred. |
860 | bool is_inferred_ = false; |
861 | }; |
862 | |
863 | struct ListType; |
864 | using ListTypePtr = std::shared_ptr<ListType>; |
865 | struct TORCH_API ListType |
866 | : public SingleElementType<TypeKind::ListType, ListType> { |
867 | // It's not exactly a singleton, but there should be exactly one instance of |
868 | // List[T] for every T |
869 | friend struct Type; |
870 | template <typename... T> |
871 | static ListTypePtr create(T&&... all) { |
872 | return ListTypePtr( |
873 | new ListType(std::forward<T>(all)...)); // NOLINT(modernize-make-shared) |
874 | } |
875 | |
876 | std::string str() const override { |
877 | std::stringstream ss; |
878 | ss << getElementType()->str() << "[]" ; |
879 | return ss.str(); |
880 | } |
881 | TypePtr createWithContained( |
882 | std::vector<TypePtr> contained_types) const override { |
883 | return create(std::move(contained_types.at(0))); |
884 | } |
885 | |
886 | bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; |
887 | |
888 | // global singleton |
889 | // Given an inner type T and an identifier, |
890 | // this function wil return the global singleton type pointer |
891 | // the type List<T>. |
892 | // The extra "identifier" argument is needed beccause we have multiple container types |
893 | // that all re-use this function (List<T>, array<T, N>, etc.) |
894 | static TypePtr get(std::string identifier, TypePtr inner); |
895 | |
896 | // common cast List[Tensor] |
897 | static ListTypePtr ofTensors(); |
898 | static ListTypePtr ofOptionalTensors(); |
899 | static ListTypePtr ofInts(); |
900 | static ListTypePtr ofFloats(); |
901 | static ListTypePtr ofComplexDoubles(); |
902 | static ListTypePtr ofBools(); |
903 | static ListTypePtr ofStrings(); |
904 | |
905 | private: |
906 | ListType(TypePtr elem) : SingleElementType(std::move(elem)) {} |
907 | |
908 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override { |
909 | std::stringstream ss; |
910 | ss << "List[" << getElementType()->annotation_str(std::move(printer)) << "]" ; |
911 | return ss.str(); |
912 | } |
913 | }; |
914 | |
915 | struct DictType; |
916 | using DictTypePtr = std::shared_ptr<DictType>; |
917 | struct TORCH_API DictType : public SharedType { |
918 | friend struct Type; |
919 | static const TypeKind Kind = TypeKind::DictType; |
920 | |
921 | static DictTypePtr create(TypePtr key, TypePtr value) { |
922 | auto kind = key->kind(); |
923 | if (auto dyn = key->castRaw<DynamicType>()) { |
924 | kind = dyn->dynamicKind(); |
925 | } |
926 | switch (kind) { |
927 | case TypeKind::AnyType: |
928 | case TypeKind::IntType: |
929 | case TypeKind::BoolType: |
930 | case TypeKind::FloatType: |
931 | case TypeKind::ComplexType: |
932 | case TypeKind::StringType: |
933 | case TypeKind::TensorType: |
934 | case TypeKind::DeviceObjType: |
935 | return DictTypePtr(new DictType(std::move(key), std::move(value))); |
936 | default: |
937 | AT_ERROR( |
938 | "Cannot create dict for key type '" , |
939 | key->str(), |
940 | "', only int, float, complex, Tensor, device and string keys are supported" ); |
941 | } |
942 | } |
943 | |
944 | // aligned with the format in FunctionSchema |
945 | std::string str() const override { |
946 | std::stringstream ss; |
947 | ss << "Dict(" << getKeyType()->str() << ", " << getValueType()->str() |
948 | << ")" ; |
949 | return ss.str(); |
950 | } |
951 | |
952 | TypePtr createWithContained( |
953 | std::vector<TypePtr> contained_types) const override { |
954 | if (contained_types.size() != 2) { |
955 | throw std::runtime_error("Expected 2 contained types" ); |
956 | } |
957 | return create(std::move(contained_types.at(0)), std::move(contained_types.at(1))); |
958 | } |
959 | |
960 | const TypePtr& getKeyType() const { |
961 | return types.at(0); |
962 | } |
963 | |
964 | const TypePtr& getValueType() const { |
965 | return types.at(1); |
966 | } |
967 | |
968 | bool hasFreeVariables() const override { |
969 | return has_free_variables; |
970 | } |
971 | |
972 | at::ArrayRef<TypePtr> containedTypes() const override { |
973 | return types; |
974 | } |
975 | |
976 | bool equals(const Type& rhs) const override { |
977 | if (auto* dict_rhs = rhs.castRaw<DictType>()) { |
978 | return *getKeyType() == *(dict_rhs->getKeyType()) && |
979 | *getValueType() == *(dict_rhs->getValueType()); |
980 | } |
981 | return false; |
982 | } |
983 | |
984 | // global singleton |
985 | // Given an inner type T and an identifier, |
986 | // this function wil return the global singleton type pointer |
987 | // the type List<T>. |
988 | // The extra "identifier" argument is needed beccause we have multiple container types |
989 | // that all re-use this function (Dict<K, V> and unordered_map<K, V>) |
990 | static TypePtr get(std::string identifier, TypePtr key, TypePtr val); |
991 | |
992 | private: |
993 | DictType(TypePtr key, TypePtr value) |
994 | : SharedType(TypeKind::DictType), |
995 | has_free_variables( |
996 | key->hasFreeVariables() || value->hasFreeVariables()) { |
997 | types.reserve(2); |
998 | types.push_back(std::move(key)); |
999 | types.push_back(std::move(value)); |
1000 | } |
1001 | |
1002 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override { |
1003 | std::stringstream ss; |
1004 | ss << "Dict[" << getKeyType()->annotation_str(printer) << ", " ; |
1005 | ss << getValueType()->annotation_str(std::move(printer)) << "]" ; |
1006 | return ss.str(); |
1007 | } |
1008 | |
1009 | std::vector<TypePtr> types; |
1010 | bool has_free_variables; |
1011 | }; |
1012 | |
1013 | struct FutureType; |
1014 | using FutureTypePtr = std::shared_ptr<FutureType>; |
1015 | |
1016 | struct TORCH_API FutureType |
1017 | : public SingleElementType<TypeKind::FutureType, FutureType> { |
1018 | friend struct Type; |
1019 | template <typename... T> |
1020 | static FutureTypePtr create(TypePtr elem) { |
1021 | return FutureTypePtr( |
1022 | new FutureType(std::move(elem))); // NOLINT(modernize-make-shared) |
1023 | } |
1024 | |
1025 | std::string str() const override { |
1026 | std::stringstream ss; |
1027 | ss << "Future(" << getElementType()->str() << ")" ; |
1028 | return ss.str(); |
1029 | } |
1030 | TypePtr createWithContained( |
1031 | std::vector<TypePtr> contained_types) const override { |
1032 | return create(std::move(contained_types.at(0))); |
1033 | } |
1034 | |
1035 | bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { |
1036 | if (Type::isSubtypeOfExt(rhs, why_not)) { |
1037 | return true; |
1038 | } |
1039 | if (auto rhs_ = rhs.castRaw<FutureType>()) { |
1040 | return getElementType()->isSubtypeOfExt(*rhs_->getElementType(), why_not); |
1041 | } |
1042 | return false; |
1043 | } |
1044 | |
1045 | private: |
1046 | FutureType(TypePtr elem) : SingleElementType(std::move(elem)) {} |
1047 | |
1048 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override { |
1049 | std::stringstream ss; |
1050 | ss << "Future[" << getElementType()->annotation_str(std::move(printer)) << "]" ; |
1051 | return ss.str(); |
1052 | } |
1053 | }; |
1054 | |
1055 | struct AwaitType; |
1056 | using AwaitTypePtr = std::shared_ptr<AwaitType>; |
1057 | |
1058 | struct TORCH_API AwaitType |
1059 | : public SingleElementType<TypeKind::AwaitType, AwaitType> { |
1060 | friend struct Type; |
1061 | template <typename... T> |
1062 | static AwaitTypePtr create(TypePtr elem) { |
1063 | return AwaitTypePtr( |
1064 | new AwaitType(std::move(elem))); // NOLINT(modernize-make-shared) |
1065 | } |
1066 | |
1067 | std::string str() const override { |
1068 | std::stringstream ss; |
1069 | ss << "Await(" << getElementType()->str() << ")" ; |
1070 | return ss.str(); |
1071 | } |
1072 | TypePtr createWithContained( |
1073 | std::vector<TypePtr> contained_types) const override { |
1074 | return create(std::move(contained_types.at(0))); |
1075 | } |
1076 | |
1077 | bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { |
1078 | if (Type::isSubtypeOfExt(rhs, why_not)) { |
1079 | return true; |
1080 | } |
1081 | if (auto rhs_ = rhs.castRaw<AwaitType>()) { |
1082 | return getElementType()->isSubtypeOfExt(*rhs_->getElementType(), why_not); |
1083 | } |
1084 | return false; |
1085 | } |
1086 | |
1087 | private: |
1088 | AwaitType(TypePtr elem) : SingleElementType(std::move(elem)) {} |
1089 | |
1090 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override { |
1091 | std::stringstream ss; |
1092 | ss << "Await[" << getElementType()->annotation_str(printer) << "]" ; |
1093 | return ss.str(); |
1094 | } |
1095 | }; |
1096 | |
1097 | struct RRefType; |
1098 | using RRefTypePtr = std::shared_ptr<RRefType>; |
1099 | |
1100 | struct TORCH_API RRefType |
1101 | : public SingleElementType<TypeKind::RRefType, RRefType> { |
1102 | friend struct Type; |
1103 | template <typename... T> |
1104 | static RRefTypePtr create(TypePtr elem) { |
1105 | return RRefTypePtr( |
1106 | new RRefType(std::move(elem))); // NOLINT(modernize-make-shared) |
1107 | } |
1108 | |
1109 | std::string str() const override { |
1110 | std::stringstream ss; |
1111 | ss << "RRef(" << getElementType()->str() << ")" ; |
1112 | return ss.str(); |
1113 | } |
1114 | TypePtr createWithContained( |
1115 | std::vector<TypePtr> contained_types) const override { |
1116 | return create(std::move(contained_types.at(0))); |
1117 | } |
1118 | |
1119 | private: |
1120 | RRefType(TypePtr elem) : SingleElementType(std::move(elem)) {} |
1121 | |
1122 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override { |
1123 | std::stringstream ss; |
1124 | ss << "RRef[" << getElementType()->annotation_str(std::move(printer)) << "]" ; |
1125 | return ss.str(); |
1126 | } |
1127 | }; |
1128 | |
// Any should never appear in a named type like a class, namedtuple or
// interface. If it does, then dynamic type information will be lost in the
// Pickler, leading to hard-to-track-down bugs that will only occur
// after saving or loading a model. This is because we rely on the
// static types in named types to reconstruct type tags of loaded
// values. Lifting this restriction requires solving the serialization
// problem first.
//
// Declared here, defined out of line. NOTE(review): parameter roles are
// inferred from their names only. `base` is presumably the named type being
// checked, `what` a human-readable description of it for the error message,
// and `attrname`/`attrtype` the attribute under inspection. Confirm against
// the definition.
TORCH_API void checkNoAny(
    const Type& base,
    const char* what,
    const std::string& attrname,
    const TypePtr& attrtype);
1141 | |
1142 | struct TupleType; |
1143 | using TupleTypePtr = std::shared_ptr<TupleType>; |
1144 | using NameList = std::vector<std::string>; |
1145 | // This type represents a Tuple |
1146 | struct TORCH_API TupleType : public NamedType { |
1147 | |
1148 | static TupleTypePtr createNamed(const c10::optional<c10::QualifiedName>& name, |
1149 | const std::vector<std::string>& field_names, |
1150 | const std::vector<TypePtr>& field_types, |
1151 | std::vector<IValue>& field_defaults); |
1152 | |
1153 | static TupleTypePtr createNamed(const c10::optional<c10::QualifiedName>& name, |
1154 | const std::vector<std::string>& field_names, |
1155 | const std::vector<TypePtr>& field_types); |
1156 | |
1157 | static TupleTypePtr createNamed(const c10::optional<c10::QualifiedName>& name, |
1158 | const std::vector<c10::string_view>& field_names, |
1159 | const std::vector<TypePtr>& field_types); |
1160 | |
1161 | static TupleTypePtr create( |
1162 | std::vector<TypePtr> types) { |
1163 | return TupleTypePtr(new TupleType( |
1164 | std::move(types), |
1165 | c10::nullopt, |
1166 | nullptr)); // NOLINT(modernize-make-shared) |
1167 | } |
1168 | static TupleTypePtr create() { |
1169 | return create({}); |
1170 | } |
1171 | |
1172 | at::ArrayRef<TypePtr> elements() const { |
1173 | return elements_; |
1174 | } |
1175 | |
1176 | bool equals(const Type& rhs) const override; |
1177 | bool isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const override; |
1178 | |
1179 | std::string str() const override; |
1180 | bool hasFreeVariables() const override { |
1181 | return has_free_variables_; |
1182 | } |
1183 | at::ArrayRef<TypePtr> containedTypes() const override { |
1184 | return elements_; |
1185 | } |
1186 | TypePtr createWithContained( |
1187 | std::vector<TypePtr> contained_types) const override { |
1188 | return std::shared_ptr<TupleType>( |
1189 | new TupleType(std::move(contained_types), name(), schema())); |
1190 | } |
1191 | const std::shared_ptr<FunctionSchema>& schema() const { |
1192 | return schema_; |
1193 | } |
1194 | c10::optional<std::vector<c10::string_view>> names() const; |
1195 | |
1196 | static const TypeKind Kind = TypeKind::TupleType; |
1197 | |
1198 | private: |
1199 | template <typename S> |
1200 | static TupleTypePtr createWithSpec( |
1201 | const c10::optional<c10::QualifiedName>& name, |
1202 | const std::vector<S>& field_names, |
1203 | const std::vector<TypePtr>& field_types, |
1204 | std::vector<IValue>& field_defaults); |
1205 | |
1206 | TupleType( |
1207 | std::vector<TypePtr> elements_, |
1208 | c10::optional<c10::QualifiedName> name, |
1209 | std::shared_ptr<FunctionSchema> schema); |
1210 | |
1211 | bool compare( |
1212 | const Type& rhs, |
1213 | std::function<bool(const Type&, const Type&)> fn) const { |
1214 | if (rhs.kind() != kind()) { |
1215 | return false; |
1216 | } |
1217 | |
1218 | const auto& l_elements = elements(); |
1219 | const auto& r_elements = rhs.castRaw<TupleType>()->elements(); |
1220 | if (l_elements.size() != r_elements.size()) |
1221 | return false; |
1222 | for (size_t i = 0; i < l_elements.size(); ++i) { |
1223 | if (!fn(*l_elements[i], *r_elements[i])) |
1224 | return false; |
1225 | } |
1226 | return true; |
1227 | } |
1228 | |
1229 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override; |
1230 | |
1231 | std::vector<TypePtr> elements_; |
1232 | bool has_free_variables_; |
1233 | std::shared_ptr<FunctionSchema> schema_; |
1234 | }; |
1235 | |
// The common supertype of all Enums; only used in operator registration.
1237 | // EnumType <: AnyEnumType for all Enums |
1238 | struct AnyEnumType; |
1239 | using AnyEnumTypePtr = SingletonTypePtr<AnyEnumType>; |
1240 | struct TORCH_API AnyEnumType final : public Type { |
1241 | bool equals(const Type& rhs) const override { |
1242 | return rhs.kind() == kind(); |
1243 | } |
1244 | std::string str() const override { |
1245 | return "AnyEnumType" ; |
1246 | } |
1247 | static const TypeKind Kind = TypeKind::AnyEnumType; |
1248 | // global singleton |
1249 | static AnyEnumTypePtr get(); |
1250 | private: |
1251 | AnyEnumType() |
1252 | : Type(TypeKind::AnyEnumType) {} |
1253 | }; |
1254 | |
1255 | struct NumberType; |
1256 | using NumberTypePtr = SingletonTypePtr<NumberType>; |
1257 | // This type represents a Python number |
1258 | // Subtype hierarchy for Number Types (NumberType as the base type): |
1259 | // IntType <: NumberType |
1260 | // FloatType <: NumberType |
1261 | // ComplexType <:NumberType |
1262 | // |
1263 | // WARNING: if you add a new subtype of NumberType that is not |
1264 | // represented by a global singleton, you need to change NumberTypePtr |
1265 | // to a SingletonOrSharedTypePtr and deal with NumberType needing to |
1266 | // both inherit and not inherit from SharedType! |
1267 | struct TORCH_API NumberType : public Type { |
1268 | bool equals(const Type& rhs) const override; |
1269 | |
1270 | bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; |
1271 | |
1272 | std::string str() const override { |
1273 | return "Scalar" ; // match what PythonArgParser says for clarity |
1274 | } |
1275 | static const TypeKind Kind = TypeKind::NumberType; |
1276 | // global singleton |
1277 | static NumberTypePtr get(); |
1278 | |
1279 | protected: |
1280 | NumberType(TypeKind kind = TypeKind::NumberType) : Type(kind) {} |
1281 | |
1282 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override { |
1283 | (void)printer; // Suppress unused variable warning |
1284 | return "number" ; // technically not a valid python type, but |
1285 | // we need to use it when parsing back in annotations |
1286 | // for implicit conversions |
1287 | } |
1288 | }; |
1289 | |
1290 | struct FloatType; |
1291 | using FloatTypePtr = SingletonTypePtr<FloatType>; |
1292 | // This type represents a Python float number |
1293 | struct TORCH_API FloatType : public NumberType { |
1294 | bool equals(const Type& rhs) const override { |
1295 | return rhs.kind() == kind(); |
1296 | } |
1297 | std::string str() const override { |
1298 | return "float" ; |
1299 | } |
1300 | bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { |
1301 | // NOLINTNEXTLINE(bugprone-parent-virtual-call) |
1302 | return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); |
1303 | } |
1304 | static const TypeKind Kind = TypeKind::FloatType; |
1305 | // global singleton |
1306 | static FloatTypePtr get(); |
1307 | |
1308 | private: |
1309 | FloatType() : NumberType(TypeKind::FloatType) {} |
1310 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override { |
1311 | (void)printer; // Suppress unused variable warning |
1312 | return "float" ; |
1313 | } |
1314 | }; |
1315 | |
1316 | struct ComplexType; |
1317 | using ComplexTypePtr = SingletonTypePtr<ComplexType>; |
1318 | // This type represents a Python float number |
1319 | struct TORCH_API ComplexType : public NumberType { |
1320 | bool equals(const Type& rhs) const override { |
1321 | return rhs.kind() == kind(); |
1322 | } |
1323 | std::string str() const override { |
1324 | return "complex" ; |
1325 | } |
1326 | bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { |
1327 | // NOLINTNEXTLINE(bugprone-parent-virtual-call) |
1328 | return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); |
1329 | } |
1330 | static const TypeKind Kind = TypeKind::ComplexType; |
1331 | // global singleton |
1332 | static ComplexTypePtr get(); |
1333 | |
1334 | private: |
1335 | ComplexType() : NumberType(TypeKind::ComplexType) {} |
1336 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override { |
1337 | (void)printer; // Suppress unused variable warning |
1338 | return "complex" ; |
1339 | } |
1340 | }; |
1341 | |
1342 | // We need to introduce `SymIntType` to represent the `SymInt` type |
1343 | // used in function schemas e.g. `aten::narrow_copy(... SymInt length) |
1344 | // `SymInt` will be used to enable tracing arithmetic operations on |
1345 | // dimension values. Please see [SymInt.h] for more information |
1346 | struct SymIntType; |
1347 | using SymIntTypePtr = SingletonTypePtr<SymIntType>; |
1348 | struct TORCH_API SymIntType : public Type { |
1349 | bool equals(const Type& rhs) const override { |
1350 | return rhs.kind() == kind(); |
1351 | } |
1352 | std::string str() const override { |
1353 | return "SymInt" ; |
1354 | } |
1355 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override { |
1356 | return "int" ; |
1357 | } |
1358 | static const TypeKind Kind = TypeKind::SymIntType; |
1359 | // global singleton |
1360 | static SymIntTypePtr get(); |
1361 | |
1362 | private: |
1363 | SymIntType() : Type(TypeKind::SymIntType) {} |
1364 | }; |
1365 | |
1366 | struct SymFloatType; |
1367 | using SymFloatTypePtr = SingletonTypePtr<SymFloatType>; |
1368 | struct TORCH_API SymFloatType : public Type { |
1369 | bool equals(const Type& rhs) const override { |
1370 | return rhs.kind() == kind(); |
1371 | } |
1372 | std::string str() const override { |
1373 | return "SymFloat" ; |
1374 | } |
1375 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override { |
1376 | return "float" ; |
1377 | } |
1378 | static const TypeKind Kind = TypeKind::SymFloatType; |
1379 | // global singleton |
1380 | static SymFloatTypePtr get(); |
1381 | |
1382 | private: |
1383 | SymFloatType() : Type(TypeKind::SymFloatType) {} |
1384 | }; |
1385 | |
1386 | struct IntType; |
1387 | using IntTypePtr = SingletonTypePtr<IntType>; |
1388 | // This type represents a Python int number |
1389 | struct TORCH_API IntType : public NumberType { |
1390 | bool equals(const Type& rhs) const override { |
1391 | return rhs.kind() == kind(); |
1392 | } |
1393 | std::string str() const override { |
1394 | return "int" ; |
1395 | } |
1396 | bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override { |
1397 | // NOLINTNEXTLINE(bugprone-parent-virtual-call) |
1398 | return rhs.kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not); |
1399 | } |
1400 | static const TypeKind Kind = TypeKind::IntType; |
1401 | // global singleton |
1402 | static IntTypePtr get(); |
1403 | |
1404 | private: |
1405 | IntType() : NumberType(TypeKind::IntType) {} |
1406 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override { |
1407 | (void)printer; // Suppress unused variable warning |
1408 | return "int" ; |
1409 | } |
1410 | }; |
1411 | |
1412 | struct BoolType; |
1413 | using BoolTypePtr = SingletonTypePtr<BoolType>; |
1414 | // This node represents a Python bool value |
1415 | struct TORCH_API BoolType : public Type { |
1416 | bool equals(const Type& rhs) const override { |
1417 | return rhs.kind() == kind(); |
1418 | } |
1419 | std::string str() const override { |
1420 | return "bool" ; |
1421 | } |
1422 | static const TypeKind Kind = TypeKind::BoolType; |
1423 | // global singleton |
1424 | static BoolTypePtr get(); |
1425 | |
1426 | private: |
1427 | BoolType() : Type(TypeKind::BoolType) {} |
1428 | }; |
1429 | |
1430 | struct StringType; |
1431 | using StringTypePtr = SingletonTypePtr<StringType>; |
1432 | // This type represents a Python string |
1433 | struct TORCH_API StringType : public Type { |
1434 | bool equals(const Type& rhs) const override { |
1435 | return rhs.kind() == kind(); |
1436 | } |
1437 | std::string str() const override { |
1438 | // we only use "str" (not "string") in both FunctionSchema and script |
1439 | return annotation_str(); |
1440 | } |
1441 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override { |
1442 | (void)printer; // Suppress unused variable warning |
1443 | return "str" ; |
1444 | } |
1445 | static const TypeKind Kind = TypeKind::StringType; |
1446 | // global singleton |
1447 | static StringTypePtr get(); |
1448 | |
1449 | private: |
1450 | StringType() : Type(TypeKind::StringType) {} |
1451 | }; |
1452 | |
1453 | struct StorageType; |
1454 | using StorageTypePtr = SingletonTypePtr<StorageType>; |
1455 | struct TORCH_API StorageType : public Type { |
1456 | bool equals(const Type& rhs) const override { |
1457 | return rhs.kind() == kind(); |
1458 | } |
1459 | std::string str() const override { |
1460 | return annotation_str(); |
1461 | } |
1462 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override { |
1463 | (void)printer; // Suppress unused variable warning |
1464 | return "Storage" ; |
1465 | } |
1466 | static const TypeKind Kind = TypeKind::StorageType; |
1467 | // global singleton |
1468 | static StorageTypePtr get(); |
1469 | |
1470 | private: |
1471 | StorageType() : Type(TypeKind::StorageType) {} |
1472 | }; |
1473 | |
1474 | struct FunctionType; |
1475 | using FunctionTypePtr = std::shared_ptr<FunctionType>; |
1476 | struct TORCH_API FunctionType : public NamedType { |
1477 | static FunctionTypePtr create(torch::jit::Function* function) { |
1478 | return FunctionTypePtr( |
1479 | new FunctionType(function)); // NOLINT(modernize-make-shared) |
1480 | } |
1481 | bool equals(const Type& rhs) const override { |
1482 | if (auto func_type = rhs.cast<FunctionType>()) { |
1483 | return func_type->function_ == function_; |
1484 | } |
1485 | |
1486 | return false; |
1487 | } |
1488 | std::string str() const override { |
1489 | return "Function" ; |
1490 | } |
1491 | torch::jit::Function* function() const { |
1492 | return function_; |
1493 | } |
1494 | static const TypeKind Kind = TypeKind::FunctionType; |
1495 | |
1496 | private: |
1497 | FunctionType(torch::jit::Function* function); |
1498 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override { |
1499 | (void)printer; // Suppress unused variable warning |
1500 | const auto& n = name().value(); |
1501 | return n.qualifiedName(); |
1502 | } |
1503 | torch::jit::Function* function_; |
1504 | }; |
1505 | |
1506 | struct NoneType; |
1507 | using NoneTypePtr = SingletonTypePtr<NoneType>; |
1508 | // This type represents a Python None |
1509 | struct TORCH_API NoneType : public Type { |
1510 | bool equals(const Type& rhs) const override { |
1511 | return rhs.kind() == kind(); |
1512 | } |
1513 | std::string str() const override { |
1514 | return "NoneType" ; |
1515 | } |
1516 | bool isSubtypeOfExt(const Type& rhs, std::ostream *why_not) const override; |
1517 | |
1518 | static const TypeKind Kind = TypeKind::NoneType; |
1519 | // global singleton |
1520 | static NoneTypePtr get(); |
1521 | |
1522 | private: |
1523 | NoneType() : Type(TypeKind::NoneType) {} |
1524 | }; |
1525 | |
1526 | struct GeneratorType; |
1527 | using GeneratorTypePtr = SingletonTypePtr<GeneratorType>; |
1528 | // This type represents a Generator |
1529 | struct TORCH_API GeneratorType : public Type { |
1530 | bool equals(const Type& rhs) const override { |
1531 | return rhs.kind() == kind(); |
1532 | } |
1533 | std::string str() const override { |
1534 | return "Generator" ; |
1535 | } |
1536 | static const TypeKind Kind = TypeKind::GeneratorType; |
1537 | // global singleton |
1538 | static GeneratorTypePtr get(); |
1539 | |
1540 | private: |
1541 | GeneratorType() : Type(TypeKind::GeneratorType) {} |
1542 | }; |
1543 | |
1544 | struct QuantizerType; |
1545 | using QuantizerTypePtr = SingletonTypePtr<QuantizerType>; |
1546 | // This type represents a Quantizer |
1547 | struct TORCH_API QuantizerType : public Type { |
1548 | bool equals(const Type& rhs) const override { |
1549 | return rhs.kind() == kind(); |
1550 | } |
1551 | std::string str() const override { |
1552 | return "Quantizer" ; |
1553 | } |
1554 | static const TypeKind Kind = TypeKind::QuantizerType; |
1555 | // global singleton |
1556 | static QuantizerTypePtr get(); |
1557 | |
1558 | private: |
1559 | QuantizerType() : Type(TypeKind::QuantizerType) {} |
1560 | }; |
1561 | |
1562 | struct QSchemeType; |
1563 | using QSchemeTypePtr = SingletonTypePtr<QSchemeType>; |
1564 | // This type represents a QScheme |
1565 | struct TORCH_API QSchemeType : public Type { |
1566 | bool equals(const Type& rhs) const override { |
1567 | return rhs.kind() == kind(); |
1568 | } |
1569 | std::string str() const override { |
1570 | return "QScheme" ; |
1571 | } |
1572 | static const TypeKind Kind = TypeKind::QSchemeType; |
1573 | // global singleton |
1574 | static QSchemeTypePtr get(); |
1575 | |
1576 | private: |
1577 | QSchemeType() : Type(TypeKind::QSchemeType) {} |
1578 | }; |
1579 | |
1580 | struct DeviceObjType; |
1581 | using DeviceObjTypePtr = SingletonTypePtr<DeviceObjType>; |
1582 | // This type represents a Device |
1583 | struct TORCH_API DeviceObjType : public Type { |
1584 | bool equals(const Type& rhs) const override { |
1585 | return rhs.kind() == kind(); |
1586 | } |
1587 | std::string str() const override { |
1588 | return "Device" ; |
1589 | } |
1590 | static const TypeKind Kind = TypeKind::DeviceObjType; |
1591 | // global singleton |
1592 | static DeviceObjTypePtr get(); |
1593 | |
1594 | private: |
1595 | DeviceObjType() : Type(TypeKind::DeviceObjType) {} |
1596 | }; |
1597 | |
1598 | struct StreamObjType; |
1599 | using StreamObjTypePtr = SingletonTypePtr<StreamObjType>; |
1600 | // This type represents a Generator |
1601 | struct TORCH_API StreamObjType : public Type { |
1602 | bool equals(const Type& rhs) const override { |
1603 | return rhs.kind() == kind(); |
1604 | } |
1605 | std::string str() const override { |
1606 | return "Stream" ; |
1607 | } |
1608 | static const TypeKind Kind = TypeKind::StreamObjType; |
1609 | // global singleton |
1610 | static StreamObjTypePtr get(); |
1611 | |
1612 | private: |
1613 | StreamObjType() : Type(TypeKind::StreamObjType) {} |
1614 | }; |
1615 | |
1616 | struct VarType; |
1617 | using VarTypePtr = std::shared_ptr<VarType>; |
1618 | // This type represents a type variable, used in FunctionSchema |
1619 | struct VarType : public SharedType { |
1620 | static VarTypePtr create(std::string name_) { |
1621 | return VarTypePtr(new VarType(std::move(name_))); |
1622 | } |
1623 | bool equals(const Type& rhs) const override { |
1624 | return rhs.kind() == kind(); |
1625 | } |
1626 | std::string str() const override { |
1627 | return name(); |
1628 | } |
1629 | const std::string& name() const { |
1630 | return name_; |
1631 | } |
1632 | bool hasFreeVariables() const override { |
1633 | return true; |
1634 | } |
1635 | static const TypeKind Kind = TypeKind::VarType; |
1636 | |
1637 | private: |
1638 | VarType(std::string name_) |
1639 | : SharedType(TypeKind::VarType), name_(std::move(name_)) {} |
1640 | std::string name_; |
1641 | }; |
1642 | |
1643 | struct CapsuleType; |
1644 | using CapsuleTypePtr = SingletonTypePtr<CapsuleType>; |
1645 | // This type represents a Python Capsule. |
1646 | // It does not appear in the IR and is only used during runtime |
1647 | struct TORCH_API CapsuleType : public Type { |
1648 | bool equals(const Type& rhs) const override { |
1649 | return rhs.kind() == kind(); |
1650 | } |
1651 | std::string str() const override { |
1652 | return "Capsule" ; |
1653 | } |
1654 | static const TypeKind Kind = TypeKind::CapsuleType; |
1655 | // global singleton |
1656 | static CapsuleTypePtr get(); |
1657 | private: |
1658 | CapsuleType() |
1659 | : Type(TypeKind::CapsuleType) {} |
1660 | }; |
1661 | |
1662 | struct PyObjectType; |
1663 | using PyObjectTypePtr = SingletonTypePtr<PyObjectType>; |
1664 | // This type represents a PyObject Type |
1665 | struct TORCH_API PyObjectType : public Type { |
1666 | bool equals(const Type& rhs) const override { |
1667 | return rhs.kind() == kind(); |
1668 | } |
1669 | std::string str() const override { |
1670 | return "PyObject" ; |
1671 | } |
1672 | static const TypeKind Kind = TypeKind::PyObjectType; |
1673 | // global singleton |
1674 | static PyObjectTypePtr get(); |
1675 | private: |
1676 | PyObjectType() |
1677 | : Type(TypeKind::PyObjectType) {} |
1678 | }; |
1679 | |
// Presumably the levels of detail used when printing types. `Default` is an
// alias for `Full`. NOTE(review): the output of each level is decided by
// printing code defined elsewhere. Confirm there before relying on it.
enum class TypeVerbosity {
  None,
  Type,
  TypeAndStride,
  Full,
  Symbolic,
  Default = Full,
};

// Returns the verbosity currently in effect (defined out of line).
// NOTE(review): presumably consulted by the operator<< overloads below.
TORCH_API TypeVerbosity type_verbosity();

// Stream-insertion operators for types and shape-related helpers; all are
// defined out of line.
TORCH_API std::ostream& operator<<(std::ostream& out, const Type& t);
template <typename T>
TORCH_API std::ostream& operator<<(
    std::ostream& out,
    const VaryingShape<T>& t);
TORCH_API std::ostream& operator<<(std::ostream& os, const SymbolicShape& s);
TORCH_API std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s);
TORCH_API std::ostream& operator<<(std::ostream& os, const Stride& s);
1699 | // what is the type, ignoring extra size/shape information? |
1700 | // e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...) |
1701 | |
1702 | // `unshapedType` is used to remove Tensor subtypes. We treat all Tensor |
1703 | // subtypes as simply "Tensor"; we also create a new version of any |
1704 | // container types in which internal Tensors have undergone the same |
1705 | // operation. This is used for type comparisons between two Tensor types |
1706 | // (`unshapedType` means that we don't falsely return `false` for e.g. |
1707 | // Tensors of different dimensions). It's also used in the alias |
1708 | // analysis pass. |
1709 | // Be careful with calls because this can be very slow. If calling this |
1710 | // on a graph, use `EraseShapeInformation` in shape_analysis.h |
1711 | inline TypePtr unshapedType(const TypePtr& type) { |
1712 | if (type->isSubtypeOf(*TensorType::get())) { |
1713 | return TensorType::get(); |
1714 | } |
1715 | at::ArrayRef<TypePtr> contained = type->containedTypes(); |
1716 | if (contained.empty()) { |
1717 | return type; |
1718 | } |
1719 | return type->withContained(fmap(type->containedTypes(), unshapedType)); |
1720 | } |
1721 | |
1722 | inline TypePtr TensorType::fromNumberType(const Type& typ) { |
1723 | if (typ.isSubtypeOf(*IntType::get())) { |
1724 | return TensorType::createContiguous(at::kLong, at::kCPU, {}); |
1725 | } else if (typ.isSubtypeOf(*FloatType::get())) { |
1726 | return TensorType::createContiguous(at::kDouble, at::kCPU, {}); |
1727 | } else if (typ.isSubtypeOf(*BoolType::get())) { |
1728 | return TensorType::createContiguous(at::kBool, at::kCPU, {}); |
1729 | } else if (typ.kind() == NumberType::Kind) { |
1730 | return TensorType::create(c10::nullopt, at::kCPU, {}, c10::nullopt); |
1731 | } |
1732 | TORCH_CHECK(false, "Unknown number type: " , typ.str()); |
1733 | } |
1734 | inline TypePtr TensorType::fromBoolType() { |
1735 | return TensorType::createContiguous(at::kBool, at::kCPU, {}); |
1736 | } |
1737 | |
1738 | inline c10::optional<c10::ScalarType> tryScalarTypeFromJitType(const Type& type) { |
1739 | if (type == *FloatType::get()) { |
1740 | return at::typeMetaToScalarType(c10::get_default_dtype()); |
1741 | } else if (type == *IntType::get()) { |
1742 | return at::ScalarType::Long; |
1743 | } else if (type == *BoolType::get()) { |
1744 | return at::ScalarType::Bool; |
1745 | } |
1746 | return c10::nullopt; |
1747 | } |
1748 | |
1749 | inline at::ScalarType scalarTypeFromJitType(const Type& type) { |
1750 | auto result = tryScalarTypeFromJitType(type); |
1751 | TORCH_CHECK( |
1752 | result, |
1753 | "Add new condition, expected Float, Complex, Int, or Bool but got" , |
1754 | type.str()); |
1755 | return *result; |
1756 | } |
1757 | |
// Attempt to find the correct supertype of the two types `t1` and `t2`.
// If no supertype is found, then nullopt will be returned if
// `default_to_union` is false, and `Union[t1, t2]` will be returned
// if it is true. If `t1 == t2`, or `t1` is a type refinement of `t2`,
// then `t2` will be returned (and vice versa).
//
// Two different tensortypes will return dynamic.
//
// Currently we chose not to support returning a NumberType for
// two types from the set of {FloatType, IntType, ComplexType}, because
// there is a lack of operator support for NumberType.
//
// If `type_hint` is an `InterfaceType`, then we can use that as a
// potential supertype for `ClassType`s in the list. Otherwise, we have
// no way to find and use some common interface type
TORCH_API c10::optional<TypePtr> unifyTypes(
    const TypePtr& t1,
    const TypePtr& t2,
    bool default_to_union = false,
    TypePtr type_hint = nullptr);

// List form of `unifyTypes`: attempts to find a common supertype for all of
// `elements`. NOTE(review): presumably writes a failure explanation to
// `why_not` when nullopt is returned, and `default_to_union` / `type_hint`
// presumably mean the same as in `unifyTypes`. Confirm against the definition.
TORCH_API c10::optional<TypePtr> unifyTypeList(
    at::ArrayRef<TypePtr> elements,
    std::ostream& why_not,
    bool default_to_union = false,
    TypePtr type_hint = nullptr);
1784 | |
1785 | namespace detail { |
// Primary template mapping a C++ type `T` to its JIT type. Any `T` without a
// dedicated specialization is looked up as a custom class via
// getCustomClassType<T>(). If that lookup throws a c10::Error, the error is
// replaced by a TORCH_CHECK failure whose message names `T`.
template <typename T>
struct getTypePtr_ final {
  static decltype(auto) call() {
    // Immediately-invoked lambda. Its return type is deduced with plain `auto`
    // rules (no reference), so call() returns the lookup result by value.
    return ([]() {
      try {
        return getCustomClassType<T>();
      } catch(const c10::Error&) {
        // The original c10::Error is discarded; only the type name is reported.
        TORCH_CHECK(
          false,
          "Type ",
          c10::util::get_fully_qualified_type_name<T>(),
          " could not be converted to any of the known types."
        );
      }
    }());
  }
};
1803 | |
1804 | template <typename T, bool fake> |
1805 | struct getMaybeFakeTypePtr_ final { |
1806 | static decltype(auto) call() { |
1807 | return getTypePtr_<T>::call(); |
1808 | } |
1809 | }; |
1810 | |
1811 | template <> |
1812 | struct getTypePtr_<at::IValue> final { |
1813 | static decltype(auto) call() { |
1814 | return AnyType::get(); |
1815 | } |
1816 | }; |
1817 | |
1818 | template <> |
1819 | struct getTypePtr_<at::Tensor> final { |
1820 | static decltype(auto) call() { |
1821 | return TensorType::get(); |
1822 | } |
1823 | }; |
1824 | template <> |
1825 | struct getTypePtr_<c10::Storage> final { |
1826 | static decltype(auto) call() { |
1827 | return StorageType::get(); |
1828 | } |
1829 | }; |
1830 | template <> |
1831 | struct getTypePtr_<c10::Stream> final { |
1832 | static decltype(auto) call() { |
1833 | return StreamObjType::get(); |
1834 | } |
1835 | }; |
1836 | template <> |
1837 | struct getTypePtr_<double> final { |
1838 | static decltype(auto) call() { |
1839 | return FloatType::get(); |
1840 | } |
1841 | }; |
1842 | template <> |
1843 | struct getTypePtr_<c10::complex<double>> final { |
1844 | static decltype(auto) call() { |
1845 | return ComplexType::get(); |
1846 | } |
1847 | }; |
1848 | template <> |
1849 | struct getTypePtr_<int64_t> final { |
1850 | static decltype(auto) call() { |
1851 | return IntType::get(); |
1852 | } |
1853 | }; |
1854 | |
1855 | template <> |
1856 | struct getMaybeFakeTypePtr_<SymInt, false> final { |
1857 | static decltype(auto) call() { |
1858 | return SymIntType::get(); |
1859 | } |
1860 | }; |
1861 | template <> |
1862 | struct getMaybeFakeTypePtr_<SymInt, true> final { |
1863 | static decltype(auto) call() { |
1864 | return IntType::get(); |
1865 | } |
1866 | }; |
1867 | |
1868 | template <> |
1869 | struct getMaybeFakeTypePtr_<SymFloat, false> final { |
1870 | static decltype(auto) call() { |
1871 | return SymFloatType::get(); |
1872 | } |
1873 | }; |
1874 | template <> |
1875 | struct getMaybeFakeTypePtr_<SymFloat, true> final { |
1876 | static decltype(auto) call() { |
1877 | return FloatType::get(); |
1878 | } |
1879 | }; |
1880 | |
1881 | template <> |
1882 | struct getTypePtr_<c10::Device> final { |
1883 | static decltype(auto) call() { |
1884 | return DeviceObjType::get(); |
1885 | } |
1886 | }; |
1887 | template <> |
1888 | struct getTypePtr_<bool> final { |
1889 | static decltype(auto) call() { |
1890 | return BoolType::get(); |
1891 | } |
1892 | }; |
1893 | template <> |
1894 | struct getTypePtr_<at::Scalar> final { |
1895 | static decltype(auto) call() { |
1896 | return NumberType::get(); |
1897 | } |
1898 | }; |
1899 | template <> |
1900 | struct getTypePtr_<c10::QScheme> final { |
1901 | static decltype(auto) call() { |
1902 | return QSchemeType::get(); |
1903 | } |
1904 | }; |
1905 | template <> |
1906 | struct getTypePtr_<at::Generator> final { |
1907 | static decltype(auto) call() { |
1908 | return TypeFactory::create<OptionalType>( |
1909 | TypeFactory::get<GeneratorType>()); |
1910 | } |
1911 | }; |
1912 | template <> |
1913 | struct getTypePtr_<std::string> final { |
1914 | static decltype(auto) call() { |
1915 | return StringType::get(); |
1916 | } |
1917 | }; |
1918 | template <> |
1919 | struct getTypePtr_<c10::string_view> final { |
1920 | static decltype(auto) call() { |
1921 | return StringType::get(); |
1922 | } |
1923 | }; |
1924 | template <> |
1925 | struct getTypePtr_<at::Dimname> final { |
1926 | static decltype(auto) call() { |
1927 | return StringType::get(); |
1928 | } |
1929 | }; |
1930 | template <class T, bool fake> |
1931 | struct getMaybeFakeTypePtr_<std::vector<T>, fake> final { |
1932 | static const auto& call() { |
1933 | static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call(); |
1934 | // The "per vector<T>" static singleton needs to live in a .cpp file, |
1935 | // otherwise we'll end up with one singleton instance per shared library. |
1936 | static auto type = ListType::get("vector" , inner_type); |
1937 | return type; |
1938 | } |
1939 | }; |
1940 | template <class T, bool fake> |
1941 | struct getMaybeFakeTypePtr_<c10::ArrayRef<T>, fake> final { |
1942 | static const auto& call() { |
1943 | static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call(); |
1944 | // The "per ArrayRef<T>" static singleton needs to live in a .cpp file, |
1945 | // otherwise we'll end up with one singleton instance per shared library. |
1946 | static auto type = ListType::get("ArrayRef" , inner_type); |
1947 | return type; |
1948 | } |
1949 | }; |
1950 | template <bool fake> |
1951 | struct getMaybeFakeTypePtr_<c10::SymIntArrayRef, fake> final { |
1952 | static const auto& call() { |
1953 | static auto type = ListType::create(getMaybeFakeTypePtr_<c10::SymInt, fake>::call()); |
1954 | return type; |
1955 | } |
1956 | }; |
1957 | template <class T, bool fake> |
1958 | struct getMaybeFakeTypePtr_<c10::List<T>, fake> final { |
1959 | static const auto& call() { |
1960 | static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call(); |
1961 | // The "per List<T>" static singleton needs to live in a .cpp file, |
1962 | // otherwise we'll end up with one singleton instance per shared library. |
1963 | static auto type = ListType::get("List" , inner_type); |
1964 | return type; |
1965 | } |
1966 | }; |
1967 | template <class T, bool fake> |
1968 | struct getMaybeFakeTypePtr_<c10::IListRef<T>, fake> final { |
1969 | static const auto& call() { |
1970 | static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call(); |
1971 | static auto type = ListType::get("List" , inner_type); |
1972 | return type; |
1973 | } |
1974 | }; |
1975 | template <class T, size_t N, bool fake> |
1976 | struct getMaybeFakeTypePtr_<std::array<T, N>, fake> final { |
1977 | static const auto& call() { |
1978 | static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call(); |
1979 | // The "per array<T, N>" static singleton needs to live in a .cpp file, |
1980 | // otherwise we'll end up with one singleton instance per shared library. |
1981 | // (Concatenating the length onto the end of the string because we want a unique |
1982 | // type_ptr created for every std::array<T, N> type). |
1983 | static auto type = ListType::get(std::string("array" ) + std::to_string(N), inner_type); |
1984 | return type; |
1985 | } |
1986 | }; |
1987 | template <class K, class V, bool fake> |
1988 | struct getMaybeFakeTypePtr_<std::unordered_map<K, V>, fake> final { |
1989 | static const auto& call() { |
1990 | static auto inner_key_type = getMaybeFakeTypePtr_<K, fake>::call(); |
1991 | static auto inner_val_type = getMaybeFakeTypePtr_<V, fake>::call(); |
1992 | // The "per unordered_map<K, V>" static singleton needs to live in a .cpp file, |
1993 | // otherwise we'll end up with one singleton instance per shared library. |
1994 | static auto type = DictType::get("unordered_map" , inner_key_type, inner_val_type); |
1995 | return type; |
1996 | } |
1997 | }; |
1998 | template <class K, class V, bool fake> |
1999 | struct getMaybeFakeTypePtr_<c10::Dict<K, V>, fake> final { |
2000 | static const auto& call() { |
2001 | static auto inner_key_type = getMaybeFakeTypePtr_<K, fake>::call(); |
2002 | static auto inner_val_type = getMaybeFakeTypePtr_<V, fake>::call(); |
2003 | // The "per Dict<K, V>" static singleton needs to live in a .cpp file, |
2004 | // otherwise we'll end up with one singleton instance per shared library. |
2005 | static auto type = DictType::get("Dict" , inner_key_type, inner_val_type); |
2006 | return type; |
2007 | } |
2008 | }; |
2009 | |
2010 | template <class T, bool fake> |
2011 | struct getMaybeFakeTypePtr_<at::optional<T>, fake> final { |
2012 | static const auto& call() { |
2013 | static auto inner_type = getMaybeFakeTypePtr_<T, fake>::call(); |
2014 | // The "per optional<T>" static singleton needs to live in a .cpp file, |
2015 | // otherwise we'll end up with one singleton instance per shared library. |
2016 | static auto type = OptionalType::get(inner_type); |
2017 | return type; |
2018 | } |
2019 | }; |
2020 | |
2021 | |
2022 | template<> |
2023 | struct getTypePtr_<at::OptionalIntArrayRef> final { |
2024 | static const auto& call() { |
2025 | static auto inner_type = getMaybeFakeTypePtr_<IntArrayRef, false>::call(); |
2026 | // The "per optional<T>" static singleton needs to live in a .cpp file, |
2027 | // otherwise we'll end up with one singleton instance per shared library. |
2028 | static auto type = OptionalType::get(inner_type); |
2029 | return type; |
2030 | } |
2031 | }; |
2032 | |
2033 | template <bool fake> |
2034 | struct getMaybeFakeTypePtr_<at::OptionalSymIntArrayRef, fake> final { |
2035 | static const auto& call() { |
2036 | // The "per optional<T>" static singleton needs to live in a .cpp file, |
2037 | // otherwise we'll end up with one singleton instance per shared library. |
2038 | static auto inner_type = getMaybeFakeTypePtr_<SymIntArrayRef, fake>::call(); |
2039 | static auto type = OptionalType::get(inner_type); |
2040 | return type; |
2041 | } |
2042 | }; |
2043 | |
2044 | template <class... Contained, bool fake> |
2045 | struct getMaybeFakeTypePtr_<std::tuple<Contained...>, fake> final { |
2046 | static const auto& call() { |
2047 | static auto type = ([]() { |
2048 | std::vector<TypePtr> contained_types = { |
2049 | (getMaybeFakeTypePtr_<Contained, fake>::call())... |
2050 | }; |
2051 | return TupleType::create(std::move(contained_types)); |
2052 | })(); |
2053 | return type; |
2054 | } |
2055 | }; |
2056 | template <> |
2057 | struct getTypePtr_<void> final { |
2058 | static decltype(auto) call() { |
2059 | return NoneType::get(); |
2060 | } |
2061 | }; |
2062 | } // namespace detail |
2063 | template <class T> |
2064 | inline decltype(auto) getTypePtr() { |
2065 | // TODO: static_assert that a templated function exists, and throw a friendly |
2066 | // error message if not |
2067 | return detail::getMaybeFakeTypePtr_<T, false>::call(); |
2068 | } |
2069 | |
2070 | template <class T> |
2071 | inline TypePtr getTypePtrCopy() { |
2072 | // TODO: static_assert that a templated function exists, and throw a friendly |
2073 | // error message if not |
2074 | return getTypePtr<T>(); |
2075 | } |
2076 | |
2077 | template <class T> |
2078 | inline decltype(auto) getFakeTypePtr() { |
2079 | return detail::getMaybeFakeTypePtr_<T, true>::call(); |
2080 | } |
2081 | |
2082 | template <class T> |
2083 | inline TypePtr getFakeTypePtrCopy() { |
2084 | return getFakeTypePtr<T>(); |
2085 | } |
2086 | |
2087 | using TypeEnv = std::unordered_map<std::string, TypePtr>; |
2088 | struct MatchTypeReturn { |
2089 | MatchTypeReturn(std::string reason) : reason_(std::move(reason)) {} |
2090 | static MatchTypeReturn Success() { |
2091 | return MatchTypeReturn(); |
2092 | } |
2093 | bool success() const { |
2094 | return !reason_.has_value(); |
2095 | } |
2096 | const std::string& reason() const { |
2097 | return reason_.value(); |
2098 | } |
2099 | |
2100 | private: |
2101 | MatchTypeReturn() |
2102 | : reason_(c10::nullopt) {} |
2103 | c10::optional<std::string> reason_; // is there is no match, this contains the reason |
2104 | }; |
2105 | |
// Attempt to match the type variables in `formal` against `actual`, adding the
// resulting bindings to `type_env`.
// If no match is possible this returns a MatchTypeReturn with r.success() == false
// and an r.reason() that describes why it could not match.
// Note: it is possible to successfully match a formal, but for type variables
// in the formal to still not be defined. In particular, None matches Optional[T]
// but does not define the value of T.
TORCH_API MatchTypeReturn
matchTypeVariables(const TypePtr& formal, const TypePtr& actual, TypeEnv& type_env);

// Replace type variables appearing in `type` with the values in
// `type_env`. Returns nullptr if a variable used in `type`
// does not appear in `type_env`.
TORCH_API TypePtr tryEvalTypeVariables(const TypePtr& type, TypeEnv& type_env);

// NOTE(review): the name suggests this reports whether a container's element
// type can be inferred from the types of its members — confirm against the
// implementation.
TORCH_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type);
2121 | |
2122 | struct InterfaceType; |
2123 | using InterfaceTypePtr = std::shared_ptr<InterfaceType>; |
2124 | |
2125 | // Interfaces are a list of abstract methods that a class might meet. |
2126 | // If a class provides those methods, it implicitly meets the interface. |
2127 | |
2128 | // Subtype relations for Interface with ClassType: |
2129 | // lhs (ClassType or InterfaceType) is a subtype of rhs if: |
2130 | // 1. lhs methods are a superset of rhs methods |
2131 | // 2. if rhs is module interface, the lhs must be module interface or module itself |
2132 | struct TORCH_API InterfaceType : public NamedType { |
2133 | static InterfaceTypePtr create( |
2134 | QualifiedName qualifiedName, bool is_module=false); |
2135 | |
2136 | bool equals(const Type& rhs) const override { |
2137 | if (auto user_rhs = rhs.castRaw<InterfaceType>()) { |
2138 | return isSubTypeImpl(*this, *user_rhs, nullptr) && |
2139 | isSubTypeImpl(*user_rhs, *this, nullptr); |
2140 | } |
2141 | return false; |
2142 | } |
2143 | |
2144 | std::string str() const override { |
2145 | return std::string("InterfaceType<" ) + name()->name() + ">" ; |
2146 | } |
2147 | |
2148 | bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; |
2149 | |
2150 | // try to find a method of this interface, |
2151 | // returns nullptr if not found. |
2152 | const FunctionSchema* getMethod(const std::string& name) const; |
2153 | void addMethod(FunctionSchema schema); |
2154 | const std::vector<FunctionSchema>& methods() const { |
2155 | return *methods_; |
2156 | } |
2157 | |
2158 | bool is_module() const override{ |
2159 | return is_module_; |
2160 | } |
2161 | static const TypeKind Kind = TypeKind::InterfaceType; |
2162 | ~InterfaceType() override; |
2163 | private: |
2164 | InterfaceType(QualifiedName name, bool is_module); |
2165 | static bool isSubTypeImpl( |
2166 | const InterfaceType& lhs, |
2167 | const InterfaceType& rhs, |
2168 | std::ostream* why_not); |
2169 | |
2170 | std::string annotation_str_impl(TypePrinter printer = nullptr) const override { |
2171 | (void)printer; // Suppress unused variable warning |
2172 | return name()->qualifiedName(); |
2173 | } |
2174 | |
2175 | // shared_ptr so that this header does not have to depend on |
2176 | // FunctionSchema.h |
2177 | std::shared_ptr<std::vector<FunctionSchema>> methods_; |
2178 | // flag to distinguish if it's an interface type from a module or not |
2179 | bool is_module_; |
2180 | }; |
2181 | |
2182 | template <TypeKind K> |
2183 | struct EnumerationType : public Type { |
2184 | static const TypeKind Kind = K; |
2185 | |
2186 | bool equals(const Type& rhs) const override { |
2187 | return rhs.kind() == kind(); |
2188 | } |
2189 | |
2190 | protected: |
2191 | EnumerationType() : Type(Kind) {} |
2192 | }; |
2193 | |
2194 | // WARNING: These enumeration types below DO NOT actually get parsed out |
2195 | // from the logical schema strings, instead they are mapped as ints. To |
2196 | // observe these types, use real_type() instead of type() on Argument |
2197 | |
2198 | struct ScalarTypeType; |
2199 | using ScalarTypeTypePtr = SingletonTypePtr<ScalarTypeType>; |
2200 | struct TORCH_API ScalarTypeType : public EnumerationType<TypeKind::ScalarTypeType> { |
2201 | std::string str() const override { |
2202 | return "ScalarType" ; |
2203 | } |
2204 | static const TypeKind Kind = TypeKind::ScalarTypeType; |
2205 | // global singleton |
2206 | static ScalarTypeTypePtr get(); |
2207 | |
2208 | private: |
2209 | ScalarTypeType() : EnumerationType() {} |
2210 | }; |
2211 | |
2212 | struct MemoryFormatType; |
2213 | using MemoryFormatTypePtr = SingletonTypePtr<MemoryFormatType>; |
2214 | struct TORCH_API MemoryFormatType : public EnumerationType<TypeKind::MemoryFormatType> { |
2215 | std::string str() const override { |
2216 | return "MemoryFormat" ; |
2217 | } |
2218 | static const TypeKind Kind = TypeKind::MemoryFormatType; |
2219 | // global singleton |
2220 | static MemoryFormatTypePtr get(); |
2221 | |
2222 | private: |
2223 | MemoryFormatType() : EnumerationType() {} |
2224 | }; |
2225 | |
2226 | struct LayoutType; |
2227 | using LayoutTypePtr = SingletonTypePtr<LayoutType>; |
2228 | struct TORCH_API LayoutType : public EnumerationType<TypeKind::LayoutType> { |
2229 | std::string str() const override { |
2230 | return "Layout" ; |
2231 | } |
2232 | static const TypeKind Kind = TypeKind::LayoutType; |
2233 | // global singleton |
2234 | static LayoutTypePtr get(); |
2235 | |
2236 | private: |
2237 | LayoutType() : EnumerationType() {} |
2238 | }; |
2239 | |
2240 | namespace detail { |
2241 | template <> |
2242 | struct getMaybeFakeTypePtr_<c10::ScalarType, false> final { |
2243 | static decltype(auto) call() { |
2244 | return ScalarTypeType::get(); |
2245 | } |
2246 | }; |
2247 | template <> |
2248 | struct getMaybeFakeTypePtr_<c10::Layout, false> final { |
2249 | static decltype(auto) call() { |
2250 | return LayoutType::get(); |
2251 | } |
2252 | }; |
2253 | template <> |
2254 | struct getMaybeFakeTypePtr_<c10::MemoryFormat, false> final { |
2255 | static decltype(auto) call() { |
2256 | return MemoryFormatType::get(); |
2257 | } |
2258 | }; |
2259 | template <> |
2260 | struct getMaybeFakeTypePtr_<c10::ScalarType, true> final { |
2261 | static decltype(auto) call() { |
2262 | return IntType::get(); |
2263 | } |
2264 | }; |
2265 | template <> |
2266 | struct getMaybeFakeTypePtr_<c10::Layout, true> final { |
2267 | static decltype(auto) call() { |
2268 | return IntType::get(); |
2269 | } |
2270 | }; |
2271 | template <> |
2272 | struct getMaybeFakeTypePtr_<c10::MemoryFormat, true> final { |
2273 | static decltype(auto) call() { |
2274 | return IntType::get(); |
2275 | } |
2276 | }; |
2277 | } // namespace detail |
2278 | |
2279 | // the common supertype of all lists, |
2280 | // List[T] <: AnyList for all T |
2281 | struct AnyListType; |
2282 | using AnyListTypePtr = SingletonTypePtr<AnyListType>; |
2283 | struct TORCH_API AnyListType : public Type { |
2284 | bool equals(const Type& rhs) const override { |
2285 | return rhs.kind() == kind(); |
2286 | } |
2287 | std::string str() const override { |
2288 | return "list" ; |
2289 | } |
2290 | static const TypeKind Kind = TypeKind::AnyListType; |
2291 | // global singleton |
2292 | static AnyListTypePtr get(); |
2293 | private: |
2294 | AnyListType() |
2295 | : Type(TypeKind::AnyListType) {} |
2296 | }; |
2297 | |
2298 | // the common supertype of all tuples, |
2299 | // Tuple[T...] <: AnyTuple for all T |
2300 | struct AnyTupleType; |
2301 | using AnyTupleTypePtr = SingletonTypePtr<AnyTupleType>; |
2302 | struct TORCH_API AnyTupleType : public Type { |
2303 | bool equals(const Type& rhs) const override { |
2304 | return rhs.kind() == kind(); |
2305 | } |
2306 | |
2307 | std::string str() const override { |
2308 | return "tuple" ; |
2309 | } |
2310 | static const TypeKind Kind = TypeKind::AnyTupleType; |
2311 | |
2312 | // global singleton |
2313 | static AnyTupleTypePtr get(); |
2314 | private: |
2315 | AnyTupleType() |
2316 | : Type(TypeKind::AnyTupleType) {} |
2317 | }; |
2318 | |
2319 | // the common supertype of all classes, |
2320 | // ClassType <: AnyClassType for all classes |
2321 | struct AnyClassType; |
2322 | using AnyClassTypePtr = SingletonTypePtr<AnyClassType>; |
2323 | struct TORCH_API AnyClassType : public Type { |
2324 | bool equals(const Type& rhs) const override { |
2325 | return rhs.kind() == kind(); |
2326 | } |
2327 | std::string str() const override { |
2328 | return "AnyClassType" ; |
2329 | } |
2330 | static const TypeKind Kind = TypeKind::AnyClassType; |
2331 | // global singleton |
2332 | static AnyClassTypePtr get(); |
2333 | private: |
2334 | AnyClassType() |
2335 | : Type(TypeKind::AnyClassType) {} |
2336 | }; |
2337 | |
2338 | template<> |
2339 | inline typename detail::CastReturnType<NamedType>::type Type::cast<NamedType>() { |
2340 | if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType || |
2341 | kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) { |
2342 | return std::static_pointer_cast<NamedType>(static_cast<NamedType *>(this)->shared_from_this()); |
2343 | } |
2344 | return nullptr; |
2345 | } |
2346 | |
2347 | template<> |
2348 | inline typename detail::CastConstReturnType<NamedType>::type Type::cast<NamedType>() const { |
2349 | if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType || |
2350 | kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) { |
2351 | return std::static_pointer_cast<const NamedType>(static_cast<const NamedType *>(this)->shared_from_this()); |
2352 | } |
2353 | return nullptr; |
2354 | } |
2355 | |
2356 | template<> |
2357 | inline const NamedType* Type::castRaw<NamedType>() const { |
2358 | if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType || |
2359 | kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) { |
2360 | return static_cast<const NamedType*>(this); |
2361 | } |
2362 | return nullptr; |
2363 | } |
2364 | |
2365 | // Used as a return type when inferring the IValue type of a Python object. |
2366 | struct InferredType { |
2367 | /* implicit */ InferredType(TypePtr type) : type_(std::move(type)) {} |
2368 | /* implicit */ InferredType(std::string reason) |
2369 | : type_(nullptr), reason_(std::move(reason)) {} |
2370 | TypePtr type() const { |
2371 | TORCH_INTERNAL_ASSERT( |
2372 | type_, |
2373 | "Tried to get the type from an InferredType but the type is null. " , |
2374 | "Reason: " , |
2375 | reason_); |
2376 | return type_; |
2377 | } |
2378 | bool success() const { |
2379 | return type_ != nullptr; |
2380 | } |
2381 | const std::string& reason() const { |
2382 | TORCH_INTERNAL_ASSERT(!type_); |
2383 | return reason_; |
2384 | } |
2385 | |
2386 | private: |
2387 | TypePtr type_; |
2388 | std::string reason_; |
2389 | }; |
2390 | |
// NOTE(review): the name suggests this returns true when `type` is AnyType or
// contains it (e.g. as a nested element type) — confirm against the
// implementation.
TORCH_API bool containsAnyType(const TypePtr& type);
2392 | |
2393 | } // namespace c10 |
2394 | |