1 | #pragma once |
2 | #include <ATen/core/List.h> |
3 | #include <ATen/core/Tensor.h> |
4 | #include <c10/core/impl/TorchDispatchModeTLS.h> |
5 | |
6 | #ifndef AT_PER_OPERATOR_HEADERS |
7 | #include <ATen/Functions.h> |
8 | #else |
9 | #include <ATen/ops/equal.h> |
10 | #endif |
11 | |
12 | namespace at { |
13 | |
14 | // Note [Tensor-subclass-like Tensors] |
15 | // Tensor-subclass-like is defined as: |
16 | // - a Tensor subclass (via __torch_dispatch__ in Python or extending |
17 | // TensorImpl in C++) |
18 | // - anything else that shares the same perils as Tensor subclasses. |
19 | // For example, many Tensor subclasses do not have storage and meta Tensors |
20 | // do not have storage either, so meta Tensors belong here. |
21 | // |
// We should ensure that PyTorch internals support Tensor-subclass-like
23 | // objects. In particular, Tensor-subclass-like objects struggle with two |
24 | // classes of operations that are problematic for Tensor subclasses: |
25 | // 1. Because some Tensor subclasses do not have storage, .item() or |
26 | // .data_ptr() calls are not good. |
27 | // 2. Certain in-place operations can eliminate the typing of the Tensor |
28 | // subclass. For example: |
29 | // >>> torch.zeros(input.sizes(), grad.options()).diag().copy_(input) |
30 | // If input is a Tensor subclass, then the above ends up either erroring out |
31 | // or returning a regular non-Tensor-subclass Tensor! |
32 | |
33 | constexpr auto kFunctorchWrappedTensors = DispatchKeySet( |
34 | {DispatchKey::FuncTorchGradWrapper, |
35 | DispatchKey::FuncTorchBatched, |
36 | DispatchKey::Functionalize}); |
37 | |
38 | constexpr auto kTensorSubclassLike = |
39 | kFunctorchWrappedTensors | |
40 | DispatchKeySet( |
41 | {// WARNING: DO NOT put combined backend component + functionality keys |
42 | // here, you will incorrectly always match on the functionality key |
43 | // no matter the backend component |
44 | DispatchKey::Batched, |
45 | DispatchKey::Sparse, |
46 | DispatchKey::SparseCsrCPU, |
47 | DispatchKey::SparseCsrCUDA, |
48 | DispatchKey::Python}) | |
49 | DispatchKeySet(BackendComponent::MetaBit); |
50 | |
51 | inline bool isTensorSubclassLike(const Tensor& tensor) { |
52 | if (c10::impl::dispatch_mode_enabled()) |
53 | return true; |
54 | auto key_set = tensor.unsafeGetTensorImpl()->key_set(); |
55 | return !(key_set & kTensorSubclassLike).empty(); |
56 | } |
57 | |
58 | inline bool areAnyTensorSubclassLike(TensorList tensors) { |
59 | if (c10::impl::dispatch_mode_enabled()) |
60 | return true; |
61 | return std::any_of(tensors.begin(), tensors.end(), isTensorSubclassLike); |
62 | } |
63 | |
64 | inline bool areAnyOptionalTensorSubclassLike( |
65 | const c10::List<c10::optional<Tensor>>& tensors) { |
66 | if (c10::impl::dispatch_mode_enabled()) |
67 | return true; |
68 | return std::any_of( |
69 | tensors.begin(), tensors.end(), [](const optional<Tensor>& opt_tensor) { |
70 | return ( |
71 | opt_tensor.has_value() && isTensorSubclassLike(opt_tensor.value())); |
72 | }); |
73 | } |
74 | |
75 | // Helper function to deal testing truthfulness of a scalar tensor |
76 | // in a Composite Compliant manner. |
77 | // NOTE: This function expects a scalar tensor of boolean dtype. |
78 | // Eg. |
79 | // Non-Composite Compliant Pattern : (t == 0).all().item<bool>() |
80 | // Composite Compliant Patter : is_salar_tensor_true((t == 0).all()) |
81 | inline bool is_scalar_tensor_true(const Tensor& t) { |
82 | TORCH_INTERNAL_ASSERT(t.dim() == 0) |
83 | TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool) |
84 | return at::equal(t, t.new_ones({}, t.options())); |
85 | } |
86 | |
87 | } // namespace at |
88 | |