// Required for old versions of g++ to compile the PRId64 macros; see
// https://github.com/pytorch/pytorch/issues/3571
// for context.
4 | #ifndef __STDC_FORMAT_MACROS |
5 | #define __STDC_FORMAT_MACROS |
6 | #endif |
7 | |
// An external backend might generate this file within its own code tree
// and check all of the source files in that tree with clang-format,
// so disable clang-format here since the backend might use a different config.
11 | // clang-format off |
12 | |
// NOTE: This condition is true for all PyTorch internal libraries; it
// only excludes external projects such as torch_xla that
// reuse some of the PyTorch codegen machinery.
16 | #if defined(CAFFE2_BUILD_MAIN_LIB) || \ |
17 | defined(TORCH_CUDA_BUILD_MAIN_LIB) || \ |
18 | defined(TORCH_HIP_BUILD_MAIN_LIB) || \ |
19 | defined(TORCH_CUDA_CU_BUILD_MAIN_LIB) || \ |
20 | defined(TORCH_CUDA_CPP_BUILD_MAIN_LIB) |
21 | #define TORCH_ASSERT_ONLY_METHOD_OPERATORS |
22 | #endif |
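
// When TORCH_ASSERT_ONLY_METHOD_OPERATORS is defined, the monolithic operator
// headers (e.g. ATen/Functions.h) refuse to compile, which forces this file to
// rely on the fine-grained ATen/ops/*.h includes used below.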
23 | |
24 | // @generated by torchgen/gen.py from RegisterDispatchKey.cpp |
25 | |
26 | #include <c10/core/TensorImpl.h> |
27 | #include <c10/core/Allocator.h> |
28 | #include <ATen/DeviceGuard.h> |
29 | #include <ATen/NamedTensorUtils.h> |
30 | #include <ATen/Utils.h> |
31 | #include <ATen/WrapDimUtils.h> |
32 | #include <ATen/Dispatch.h> |
33 | #include <c10/util/ExclusivelyOwned.h> |
34 | #include <c10/util/Half.h> |
35 | #include <c10/core/UndefinedTensorImpl.h> |
36 | #include <c10/util/Optional.h> |
37 | #include <ATen/Tensor.h> |
38 | #include <ATen/native/Resize.h> |
39 | |
40 | #include <cstddef> |
41 | #include <functional> |
42 | #include <memory> |
43 | #include <utility> |
44 | |
45 | #include <ATen/Config.h> |
46 | #include <ATen/core/op_registration/adaption.h> |
47 | #include <torch/library.h> |
48 | |
49 | |
50 | #include <ATen/ops/as_strided_native.h> |
51 | #include <ATen/ops/empty.h> |
52 | #include <ATen/ops/empty_strided.h> |
53 | #include <ATen/ops/_copy_from_and_resize.h> |
54 | #include <ATen/ops/_copy_from.h> |
55 | #include <ATen/ops/_reshape_alias_native.h> |
56 | #include <ATen/ops/add_native.h> |
58 | #include <ATen/ops/div_native.h> |
59 | #include <ATen/ops/linalg_cross_native.h> |
60 | #include <ATen/ops/mul_native.h> |
61 | #include <ATen/ops/sub_native.h> |
62 | #include <ATen/ops/view_native.h> |
63 | |
64 | // See template file RegisterDispatchDefinitions.ini |
65 | namespace at { |
// NB: TORCH_LIBRARY_IMPL must be in an anonymous namespace to avoid
// ambiguity with conflicting identifiers that may already have been defined
// in the at namespace.
69 | namespace { |
70 | void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) { |
  TORCH_CHECK(options.dtype() == out.dtype(),
      "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
  TORCH_CHECK(options.device() == out.device(),
      "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
75 | const bool resized = at::native::resize_output(out, sizes); |
76 | // Only restride if a resize occurred; otherwise we ignore the (advisory) |
77 | // strides from the meta function and directly use the output tensor's |
78 | // preexisting strides |
79 | if (resized) { |
80 | if (!strides.empty()) { |
81 | TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value()); |
82 | // TODO: avoid the redispatch here |
83 | out.as_strided_(sizes, strides); |
84 | } else if (options.memory_format_opt().has_value()) { |
85 | out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt()); |
86 | } |
87 | } |
88 | } |
89 | void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) { |
90 | // These checks are needed on those operators that: |
91 | // 1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm') |
92 | // 2) have particular typing rules (e.g. 'cumsum' and 'cumprod') |
93 | // For other operators (e.g. 'add'), 'TensorIterator' already checks |
94 | // these things separately. |
  TORCH_CHECK(options.dtype() == self.dtype(),
      "Bad in-place call: ",
      "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
  TORCH_CHECK(options.device() == self.device(),
      "Bad in-place call: ",
      "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
  TORCH_CHECK(sizes == self.sizes(),
      "Bad in-place call: ",
      "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
104 | } |
105 | namespace { |
106 | at::Tensor wrapper_ZeroTensor_Tensor_add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { |
107 | // No device check |
108 | // DeviceGuard omitted |
109 | return at::native::add_zerotensor(self, other, alpha); |
110 | } |
111 | } // anonymous namespace |
112 | namespace { |
113 | at::Tensor wrapper_ZeroTensor__as_strided(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset) { |
114 | // No device check |
115 | // DeviceGuard omitted |
116 | return at::native::as_strided_tensorimpl(self, C10_AS_INTARRAYREF_SLOW(size), C10_AS_INTARRAYREF_SLOW(stride), storage_offset.has_value() ? c10::make_optional(storage_offset->expect_int()) : c10::nullopt); |
117 | } |
118 | } // anonymous namespace |
119 | namespace { |
120 | at::Tensor wrapper_ZeroTensor_Tensor_div(const at::Tensor & self, const at::Tensor & other) { |
121 | // No device check |
122 | // DeviceGuard omitted |
123 | return at::native::div_zerotensor(self, other); |
124 | } |
125 | } // anonymous namespace |
126 | namespace { |
127 | at::Tensor wrapper_ZeroTensor_Tensor_mul(const at::Tensor & self, const at::Tensor & other) { |
128 | // No device check |
129 | // DeviceGuard omitted |
130 | return at::native::mul_zerotensor(self, other); |
131 | } |
132 | } // anonymous namespace |
133 | namespace { |
134 | at::Tensor wrapper_ZeroTensor___reshape_alias(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { |
135 | // No device check |
136 | // DeviceGuard omitted |
137 | return at::native::_reshape_alias(self, C10_AS_INTARRAYREF_SLOW(size), C10_AS_INTARRAYREF_SLOW(stride)); |
138 | } |
139 | } // anonymous namespace |
140 | namespace { |
141 | at::Tensor wrapper_ZeroTensor_Tensor_sub(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { |
142 | // No device check |
143 | // DeviceGuard omitted |
144 | return at::native::sub_zerotensor(self, other, alpha); |
145 | } |
146 | } // anonymous namespace |
147 | namespace { |
148 | at::Tensor wrapper_ZeroTensor__view(const at::Tensor & self, c10::SymIntArrayRef size) { |
149 | // No device check |
150 | // DeviceGuard omitted |
151 | return at::native::view(self, C10_AS_INTARRAYREF_SLOW(size)); |
152 | } |
153 | } // anonymous namespace |
154 | namespace { |
155 | at::Tensor wrapper_ZeroTensor__linalg_cross(const at::Tensor & self, const at::Tensor & other, int64_t dim) { |
156 | // No device check |
157 | // DeviceGuard omitted |
158 | return at::native::linalg_cross_zerotensor(self, other, dim); |
159 | } |
160 | } // anonymous namespace |
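// Register the wrappers above under the ZeroTensor dispatch key, so the
// dispatcher routes calls whose inputs carry that key to these kernels.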
161 | TORCH_LIBRARY_IMPL(aten, ZeroTensor, m) { |
162 | m.impl("add.Tensor" , |
163 | TORCH_FN(wrapper_ZeroTensor_Tensor_add)); |
164 | m.impl("as_strided" , |
165 | TORCH_FN(wrapper_ZeroTensor__as_strided)); |
166 | m.impl("div.Tensor" , |
167 | TORCH_FN(wrapper_ZeroTensor_Tensor_div)); |
168 | m.impl("mul.Tensor" , |
169 | TORCH_FN(wrapper_ZeroTensor_Tensor_mul)); |
170 | m.impl("_reshape_alias" , |
171 | TORCH_FN(wrapper_ZeroTensor___reshape_alias)); |
172 | m.impl("sub.Tensor" , |
173 | TORCH_FN(wrapper_ZeroTensor_Tensor_sub)); |
174 | m.impl("view" , |
175 | TORCH_FN(wrapper_ZeroTensor__view)); |
176 | m.impl("linalg_cross" , |
177 | TORCH_FN(wrapper_ZeroTensor__linalg_cross)); |
178 | }; |
179 | } // anonymous namespace |
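// A minimal usage sketch (assuming at::_efficientzerotensor is available in
// this build and produces a tensor carrying the ZeroTensor dispatch key) of
// how the registrations above are reached:
//
//   at::Tensor x = at::rand({2, 3});
//   at::Tensor z = at::_efficientzerotensor({2, 3});
//   at::Tensor y = z.add(x);  // dispatches to wrapper_ZeroTensor_Tensor_add
//
// The at::zerotensor namespace below exposes the same kernels as plain
// functions that call the wrappers directly, bypassing the dispatcher.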
180 | namespace zerotensor { |
181 | at::Tensor add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { |
182 | return wrapper_ZeroTensor_Tensor_add(self, other, alpha); |
183 | } |
184 | at::Tensor as_strided(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset) { |
185 | return wrapper_ZeroTensor__as_strided(self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? c10::make_optional(c10::SymInt(*storage_offset)) : c10::nullopt); |
186 | } |
187 | at::Tensor as_strided_symint(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset) { |
188 | return wrapper_ZeroTensor__as_strided(self, size, stride, storage_offset); |
189 | } |
190 | at::Tensor div(const at::Tensor & self, const at::Tensor & other) { |
191 | return wrapper_ZeroTensor_Tensor_div(self, other); |
192 | } |
193 | at::Tensor mul(const at::Tensor & self, const at::Tensor & other) { |
194 | return wrapper_ZeroTensor_Tensor_mul(self, other); |
195 | } |
196 | at::Tensor _reshape_alias(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride) { |
197 | return wrapper_ZeroTensor___reshape_alias(self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride)); |
198 | } |
199 | at::Tensor _reshape_alias_symint(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { |
200 | return wrapper_ZeroTensor___reshape_alias(self, size, stride); |
201 | } |
202 | at::Tensor sub(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { |
203 | return wrapper_ZeroTensor_Tensor_sub(self, other, alpha); |
204 | } |
205 | at::Tensor view(const at::Tensor & self, at::IntArrayRef size) { |
206 | return wrapper_ZeroTensor__view(self, c10::fromIntArrayRefSlow(size)); |
207 | } |
208 | at::Tensor view_symint(const at::Tensor & self, c10::SymIntArrayRef size) { |
209 | return wrapper_ZeroTensor__view(self, size); |
210 | } |
211 | at::Tensor linalg_cross(const at::Tensor & self, const at::Tensor & other, int64_t dim) { |
212 | return wrapper_ZeroTensor__linalg_cross(self, other, dim); |
213 | } |
214 | } // namespace zerotensor |
215 | } // namespace at |
216 | |