#pragma once

#include <ATen/ExpandUtils.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/DimVector.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/util/variant.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/zeros.h>
#endif
#include <cstdint>
#include <sstream>
#include <utility>

namespace torch {
namespace autograd {

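// The shape of an input is stored either as a small vector of symbolic ints
// (dense tensors) or as the nested size tensor (nested tensors).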
using SymIntSmallVec = c10::SmallVector<c10::SymInt, c10::kDimVectorStaticSize>;
using MetadataShape = c10::variant<SymIntSmallVec, at::Tensor>;

/**
 * Records the TensorOptions and shape of a tensor, whether or not the Python
 * dispatch key is set (tensor subclass), and, where applicable, the stream
 * the corresponding operation took place on.
 *
 * A default-constructed InputMetadata corresponds to an input that is not
 * used and may be an undefined tensor.
 */
struct InputMetadata {
  InputMetadata() = default;

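  // Records the metadata directly; also captures the current stream on the
  // given device (per the class comment, the stream the corresponding
  // operation took place on).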
  InputMetadata(
      const at::TensorOptions options,
      MetadataShape input_shape,
      bool is_tensor_subclass)
      : options_{options},
        shape_{std::move(input_shape)},
        is_tensor_subclass_{is_tensor_subclass} {
    auto device = options.device();
    stream_ = c10::impl::getDeviceGuardImpl(device.type())->getStream(device);
  }

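  // Convenience constructor that derives all metadata from the tensor
  // itself. See compute_variant_shape() for how the shape is stored.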
  InputMetadata(const at::Tensor& t)
      : InputMetadata(
            t.options(),
            compute_variant_shape(t),
            t.unsafeGetTensorImpl()->is_python_dispatch()) {}

  const at::TensorOptions options() const {
    return options_;
  }

  caffe2::TypeMeta dtype() const {
    return options_.dtype();
  }

  at::Device device() const {
    return options_.device();
  }

  at::Layout layout() const {
    return options_.layout();
  }

  c10::Stream stream() const {
    return stream_;
  }

  bool is_tensor_subclass() const {
    return is_tensor_subclass_;
  }

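  // Materializes a zero-filled tensor with the recorded (symbolic) shape and
  // options. Nested tensors are not supported, since their shape is a ragged
  // size tensor rather than a single dim vector.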
  at::Tensor zeros_like() const {
    TORCH_CHECK(
        !is_nested_tensor(),
        "Zeros is not currently supported for nested tensors.");
    return at::zeros_symint(shape_as_dim_vector(), options_);
  }

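  // Returns true iff grad has exactly the recorded shape. Nested shapes are
  // compared via their nested size tensors, dense shapes via their symbolic
  // sizes; grad and the metadata must agree on nestedness.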
  bool is_same_shape(const at::Tensor& grad) const {
    TORCH_CHECK(
        grad.is_nested() == is_nested_tensor(),
        "Both grad and InputMetadata need to be either nested or non-nested tensors.");
    if (grad.is_nested()) {
      return at::native::get_nested_size_tensor(grad).is_same_size(
          shape_as_tensor());
    }
    return grad.sym_sizes().equals(shape_as_dim_vector());
  }
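
  // Returns true iff the recorded shape is broadcastable to grad's shape
  // (i.e. grad may be a broadcast of the input). Always false for nested
  // tensors, which do not currently support expansion.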
  bool is_expandable_to_shape(const at::Tensor& grad) const {
    // NestedTensors are not currently expandable. If that support is added,
    // reduce_grad will need to be updated accordingly.
    TORCH_CHECK(
        grad.is_nested() == is_nested_tensor(),
        "Both grad and InputMetadata need to be either nested or non-nested tensors.");
    return grad.is_nested()
        ? false
        : at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
  }

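  // Sums grad down to the recorded shape; this is the gradient counterpart
  // of broadcasting the input up to grad's shape.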
  at::Tensor reduce_grad(at::Tensor& grad) const {
    // reduce_grad is only called when is_expandable_to_shape returns true.
    // For nested tensors that is always false, so this assert should never
    // fire.
    TORCH_INTERNAL_ASSERT(!grad.is_nested() && !is_nested_tensor());
    return at::sum_to(std::move(grad), shape_as_dim_vector());
  }

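  // Builds the error message reported when a gradient's shape is
  // incompatible with this metadata; index identifies the offending input.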
  std::stringstream incompatible_shape_error_message(
      const size_t index,
      const at::Tensor& grad) const {
    std::stringstream ss;
    ss << "invalid gradient at index " << index << " - got ";
    if (grad.is_nested()) {
      ss << at::native::get_nested_size_tensor(grad);
    } else {
      ss << grad.sym_sizes();
    }
    ss << " but expected shape compatible with ";
    if (is_nested_tensor()) {
      ss << shape_as_tensor();
    } else {
      ss << shape_as_dim_vector();
    }
    return ss;
  }

 private:
  bool is_nested_tensor() const {
    return c10::holds_alternative<at::Tensor>(shape_);
  }
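
  // Stores the nested size tensor for nested tensors and the symbolic sizes
  // for all other tensors.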
  MetadataShape compute_variant_shape(const at::Tensor& input) {
    if (input.is_nested()) {
      auto nested_size = at::native::get_nested_size_tensor(input);
      return MetadataShape{c10::in_place_type<at::Tensor>, nested_size};
    }
    return MetadataShape{c10::in_place_type<SymIntSmallVec>, input.sym_sizes()};
  }

  c10::SymIntArrayRef shape_as_dim_vector() const {
    const auto& dim_shape = c10::get<SymIntSmallVec>(shape_);
    return c10::SymIntArrayRef(dim_shape.data(), dim_shape.size());
  }

  at::Tensor shape_as_tensor() const {
    return c10::get<at::Tensor>(shape_);
  }

  const at::TensorOptions options_;
  MetadataShape shape_;
  c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device());
  bool is_tensor_subclass_ = false;
};
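
// A minimal usage sketch (illustrative only, not part of this header; the
// names `input`, `grad`, and `input_index` are placeholders). The autograd
// engine records InputMetadata per input and validates an incoming gradient
// against it roughly like this:
//
//   torch::autograd::InputMetadata meta(input);
//   if (!meta.is_same_shape(grad)) {
//     if (meta.is_expandable_to_shape(grad)) {
//       grad = meta.reduce_grad(grad);
//     } else {
//       TORCH_CHECK(
//           false,
//           meta.incompatible_shape_error_message(input_index, grad).str());
//     }
//   }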
} // namespace autograd
} // namespace torch