#pragma once
#include <ATen/core/List.h>
#include <ATen/core/Tensor.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/equal.h>
#endif

#include <algorithm> // for std::any_of

namespace at {

// Note [Tensor-subclass-like Tensors]
// Tensor-subclass-like is defined as:
// - a Tensor subclass (via __torch_dispatch__ in Python or extending
//   TensorImpl in C++)
// - anything else that shares the same perils as Tensor subclasses.
//   For example, many Tensor subclasses do not have storage, and meta Tensors
//   do not have storage either, so meta Tensors belong here.
//
// We should ensure that PyTorch internals support Tensor-subclass-like
// objects. In particular, two classes of operations are problematic for
// Tensor-subclass-like objects:
// 1. Because some Tensor subclasses do not have storage, .item() and
//    .data_ptr() calls must be avoided.
// 2. Certain in-place operations can eliminate the typing of the Tensor
//    subclass. For example:
//    >>> torch.zeros(input.sizes(), grad.options()).diag().copy_(input)
//    If input is a Tensor subclass, then the above either errors out or
//    returns a regular non-Tensor-subclass Tensor!
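//
// The helpers declared below let composite kernels detect such inputs and
// route around both problems. A common shape (a sketch, not a prescription):
// - instead of `(t == 0).all().item<bool>()`, call
//   `is_scalar_tensor_true((t == 0).all())`;
// - instead of writing in-place into a freshly allocated plain Tensor,
//   branch on areAnyTensorSubclassLike(...) and keep Tensor-subclass-like
//   inputs on an out-of-place path (see the usage sketch after
//   areAnyTensorSubclassLike below).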

constexpr auto kFunctorchWrappedTensors = DispatchKeySet(
    {DispatchKey::FuncTorchGradWrapper,
     DispatchKey::FuncTorchBatched,
     DispatchKey::Functionalize});

constexpr auto kTensorSubclassLike =
    kFunctorchWrappedTensors |
    DispatchKeySet(
        {// WARNING: DO NOT put combined backend component + functionality
         // keys here; you will incorrectly always match on the functionality
         // key no matter the backend component.
         DispatchKey::Batched,
         DispatchKey::Sparse,
         DispatchKey::SparseCsrCPU,
         DispatchKey::SparseCsrCUDA,
         DispatchKey::Python}) |
    DispatchKeySet(BackendComponent::MetaBit);

inline bool isTensorSubclassLike(const Tensor& tensor) {
  if (c10::impl::dispatch_mode_enabled())
    return true;
  auto key_set = tensor.unsafeGetTensorImpl()->key_set();
  return !(key_set & kTensorSubclassLike).empty();
}

inline bool areAnyTensorSubclassLike(TensorList tensors) {
  if (c10::impl::dispatch_mode_enabled())
    return true;
  return std::any_of(tensors.begin(), tensors.end(), isTensorSubclassLike);
}
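
// Typical usage (a sketch; `grad`, `index`, and `self_sizes` are hypothetical
// kernel arguments, not names defined in this header): guard an in-place fast
// path so Tensor-subclass-like inputs take an out-of-place,
// subclass-preserving path instead.
//
//   Tensor grad_self;
//   if (areAnyTensorSubclassLike({grad, index})) {
//     // Out-of-place: dispatch sees the subclass and the result keeps it.
//     grad_self =
//         at::zeros(self_sizes, grad.options()).index_add(0, index, grad);
//   } else {
//     // In-place fast path: writing into a freshly allocated plain Tensor
//     // would drop or reject a subclass, so only take it for regular Tensors.
//     grad_self = at::zeros(self_sizes, grad.options());
//     grad_self.index_add_(0, index, grad);
//   }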

inline bool areAnyOptionalTensorSubclassLike(
    const c10::List<c10::optional<Tensor>>& tensors) {
  if (c10::impl::dispatch_mode_enabled())
    return true;
  return std::any_of(
      tensors.begin(),
      tensors.end(),
      [](const c10::optional<Tensor>& opt_tensor) {
        return (
            opt_tensor.has_value() && isTensorSubclassLike(opt_tensor.value()));
      });
}
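
// A common caller shape (a sketch; `indices` stands for the kind of
// c10::List<c10::optional<Tensor>> argument that ops like index_put_ take,
// and the surrounding kernel is hypothetical):
//
//   if (areAnyOptionalTensorSubclassLike(indices)) {
//     // Compose out-of-place ops so any subclass among the indices is
//     // preserved; avoid .data_ptr()/.item() and in-place writes into
//     // plain buffers.
//   }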

// Helper function for testing the truth value of a scalar tensor
// in a Composite Compliant manner.
// NOTE: This function expects a scalar tensor of boolean dtype.
// E.g.
// Non-Composite Compliant Pattern : (t == 0).all().item<bool>()
// Composite Compliant Pattern     : is_scalar_tensor_true((t == 0).all())
inline bool is_scalar_tensor_true(const Tensor& t) {
  TORCH_INTERNAL_ASSERT(t.dim() == 0);
  TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool);
  return at::equal(t, t.new_ones({}, t.options()));
}
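
// Call-site sketch (`mask` is a hypothetical boolean tensor): instead of
//   if (mask.any().item<bool>()) { ... }
// which must read data out of the tensor, write
//   if (is_scalar_tensor_true(mask.any())) { ... }
// so functorch-wrapped and __torch_dispatch__ tensors stay on the dispatcher
// path.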

} // namespace at