#pragma once

#include <ATen/ExpandUtils.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/DimVector.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/util/variant.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/zeros.h>
#endif

#include <cstdint>
#include <sstream>
#include <utility>
25 | |
26 | namespace torch { |
27 | namespace autograd { |
28 | |
29 | using SymIntSmallVec = c10::SmallVector<c10::SymInt, c10::kDimVectorStaticSize>; |
30 | using MetadataShape = c10::variant<SymIntSmallVec, at::Tensor>; |
31 | |
32 | /** |
33 | * Records TensorOptions, shape of the tensor, whether or not the Python |
34 | * dispatch key is set (tensor subclass), and, where applicable, the stream the |
35 | * corresponding operation took place on. |
36 | * |
37 | * If is_valid() is false, then the corresponding input is not used and may be |
38 | * an undefined tensor. |
39 | */ |
40 | struct InputMetadata { |
41 | InputMetadata() = default; |
42 | |
43 | InputMetadata( |
44 | const at::TensorOptions options, |
45 | MetadataShape input_shape, |
46 | bool is_tensor_subclass) |
47 | : options_{options}, |
48 | shape_{std::move(input_shape)}, |
49 | is_tensor_subclass_{is_tensor_subclass} { |
50 | auto device_ = options.device(); |
51 | stream_ = c10::impl::getDeviceGuardImpl(device_.type())->getStream(device_); |
52 | } |
53 | |
54 | InputMetadata(const at::Tensor& t) |
55 | : InputMetadata( |
56 | t.options(), |
57 | compute_variant_shape(t), |
58 | t.unsafeGetTensorImpl()->is_python_dispatch()) {} |
59 | |
60 | const at::TensorOptions options() const { |
61 | return options_; |
62 | } |
63 | |
64 | caffe2::TypeMeta dtype() const { |
65 | return options_.dtype(); |
66 | } |
67 | |
68 | at::Device device() const { |
69 | return options_.device(); |
70 | } |
71 | |
72 | at::Layout layout() const { |
73 | return options_.layout(); |
74 | } |
75 | |
76 | c10::Stream stream() const { |
77 | return stream_; |
78 | } |
79 | |
80 | bool is_tensor_subclass() const { |
81 | return is_tensor_subclass_; |
82 | } |
83 | |
84 | at::Tensor zeros_like() const { |
85 | TORCH_CHECK( |
86 | !is_nested_tensor(), |
87 | "Zeros is not currently supported for nested tensors." ) |
88 | return at::zeros_symint(shape_as_dim_vector(), options_); |
89 | } |
90 | |
91 | bool is_same_shape(const at::Tensor& grad) const { |
92 | TORCH_CHECK( |
93 | grad.is_nested() == is_nested_tensor(), |
94 | "Both grad and InputMetadata need to be either nested or non nested tensors." ) |
95 | if (grad.is_nested()) { |
96 | return at::native::get_nested_size_tensor(grad).is_same_size( |
97 | shape_as_tensor()); |
98 | } |
99 | return grad.sym_sizes().equals(shape_as_dim_vector()); |
100 | } |
101 | bool is_expandable_to_shape(const at::Tensor& grad) const { |
102 | // Currently NestedTensors are not expandable. If this support is added then |
103 | // updates to reduce_grad will be needed |
104 | TORCH_CHECK( |
105 | grad.is_nested() == is_nested_tensor(), |
106 | "Both grad and InputMetadata need to be either nested or non nested tensors." ) |
107 | return grad.is_nested() |
108 | ? false |
109 | : at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes()); |
110 | } |
111 | |
112 | at::Tensor reduce_grad(at::Tensor& grad) const { |
113 | // Currently reduce_grad is only called if is_expandable_to_shape returns |
114 | // true For nested tensors this always returns False, so this check |
115 | // shouldn't fail |
116 | TORCH_INTERNAL_ASSERT(!grad.is_nested() && !is_nested_tensor()) |
117 | return at::sum_to(std::move(grad), shape_as_dim_vector()); |
118 | } |
119 | |
120 | std::stringstream incompatible_shape_error_message( |
121 | const size_t index, |
122 | const at::Tensor& grad) const { |
123 | std::stringstream ss; |
124 | ss << "invalid gradient at index " << index << " - got " ; |
125 | if (grad.is_nested()) { |
126 | ss << at::native::get_nested_size_tensor(grad); |
127 | } else { |
128 | ss << grad.sym_sizes(); |
129 | } |
130 | ss << " but expected shape compatible with " ; |
131 | if (is_nested_tensor()) { |
132 | ss << shape_as_tensor(); |
133 | } else { |
134 | ss << shape_as_dim_vector(); |
135 | } |
136 | return ss; |
137 | } |
138 | |
139 | private: |
140 | bool is_nested_tensor() const { |
141 | return (c10::holds_alternative<at::Tensor>(shape_)); |
142 | } |
143 | MetadataShape compute_variant_shape(const at::Tensor& input) { |
144 | if (input.is_nested()) { |
145 | auto nested_size = at::native::get_nested_size_tensor(input); |
146 | return MetadataShape{c10::in_place_type<at::Tensor>, nested_size}; |
147 | } |
148 | return MetadataShape{c10::in_place_type<SymIntSmallVec>, input.sym_sizes()}; |
149 | } |
150 | |
151 | c10::SymIntArrayRef shape_as_dim_vector() const { |
152 | const auto& dim_shape = c10::get<SymIntSmallVec>(shape_); |
153 | return c10::SymIntArrayRef(dim_shape.data(), dim_shape.size()); |
154 | } |
155 | |
156 | at::Tensor shape_as_tensor() const { |
157 | return c10::get<at::Tensor>(shape_); |
158 | } |
159 | |
160 | const at::TensorOptions options_; |
161 | MetadataShape shape_; |
162 | c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device()); |
163 | bool is_tensor_subclass_ = false; |
164 | }; |
165 | } // namespace autograd |
166 | } // namespace torch |
167 | |