// required for old g++ to compile PRId64 macros, see
// https://github.com/pytorch/pytorch/issues/3571
// for context
#ifndef __STDC_FORMAT_MACROS
#define __STDC_FORMAT_MACROS
#endif

// An external backend might generate this file within its own code tree
// and check all the source files within that tree with clang-format.
// So, disable clang-format here, since the backend might use a different config.
// clang-format off

// NOTE: This condition is true for all PyTorch internal libraries; it
// just excludes external projects such as torch_xla which
// re-use some of the PyTorch codegen machinery.
#if defined(CAFFE2_BUILD_MAIN_LIB) || \
    defined(TORCH_CUDA_BUILD_MAIN_LIB) || \
    defined(TORCH_HIP_BUILD_MAIN_LIB) || \
    defined(TORCH_CUDA_CU_BUILD_MAIN_LIB) || \
    defined(TORCH_CUDA_CPP_BUILD_MAIN_LIB)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#endif
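// Sketch of the intent (illustrative comment, not emitted by torchgen): with
// TORCH_ASSERT_ONLY_METHOD_OPERATORS defined, the umbrella operator headers refuse
// to compile, nudging internal code toward per-operator includes like the
// <ATen/ops/*.h> headers used below:
//
//   #include <ATen/Functions.h>        // would now trigger a compile-time error here
//   #include <ATen/ops/add_native.h>   // per-operator include, still fine
//
// The actual enforcement lives in the umbrella headers themselves.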

// @generated by torchgen/gen.py from RegisterDispatchKey.cpp

#include <c10/core/TensorImpl.h>
#include <c10/core/Allocator.h>
#include <ATen/DeviceGuard.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/Dispatch.h>
#include <c10/util/ExclusivelyOwned.h>
#include <c10/util/Half.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/util/Optional.h>
#include <ATen/Tensor.h>
#include <ATen/native/Resize.h>

#include <cstddef>
#include <functional>
#include <memory>
#include <utility>

#include <ATen/Config.h>
#include <ATen/core/op_registration/adaption.h>
#include <torch/library.h>

#include <ATen/ops/as_strided_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/_copy_from_and_resize.h>
#include <ATen/ops/_copy_from.h>
#include <ATen/ops/_reshape_alias_native.h>
#include <ATen/ops/add_native.h>
#include <ATen/ops/div_native.h>
#include <ATen/ops/linalg_cross_native.h>
#include <ATen/ops/mul_native.h>
#include <ATen/ops/sub_native.h>
#include <ATen/ops/view_native.h>

// See template file RegisterDispatchDefinitions.ini
namespace at {
// NB: TORCH_LIBRARY_IMPL must be in an anonymous namespace to avoid
// ambiguity with conflicting identifiers that may have already been
// defined in the at namespace.
namespace {
void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
  TORCH_CHECK(options.dtype() == out.dtype(),
      "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
  TORCH_CHECK(options.device() == out.device(),
      "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
  const bool resized = at::native::resize_output(out, sizes);
  // Only restride if a resize occurred; otherwise we ignore the (advisory)
  // strides from the meta function and directly use the output tensor's
  // preexisting strides
  if (resized) {
    if (!strides.empty()) {
      TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
      // TODO: avoid the redispatch here
      out.as_strided_(sizes, strides);
    } else if (options.memory_format_opt().has_value()) {
      out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
    }
  }
}
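// Usage sketch (illustrative only, not emitted by torchgen): an out= kernel for a
// hypothetical op would validate and prepare its output with the helper above
// before writing into it, e.g.
//
//   at::Tensor& my_op_out(const at::Tensor& self, at::Tensor& out) {  // hypothetical op
//     // Checks dtype/device and resizes `out` to `self.sizes()`; empty strides
//     // mean "keep whatever layout resize_output picked".
//     resize_out(out, self.sizes(), /*strides=*/{}, self.options());
//     // ... compute into `out` ...
//     return out;
//   }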
void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
  // These checks are needed on those operators that:
  // 1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm')
  // 2) have particular typing rules (e.g. 'cumsum' and 'cumprod')
  // For other operators (e.g. 'add'), 'TensorIterator' already checks
  // these things separately.
  TORCH_CHECK(options.dtype() == self.dtype(),
      "Bad in-place call: ",
      "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
  TORCH_CHECK(options.device() == self.device(),
      "Bad in-place call: ",
      "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
  TORCH_CHECK(sizes == self.sizes(),
      "Bad in-place call: ",
      "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
}
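// Usage sketch (illustrative only, not emitted by torchgen): an in-place kernel that
// bypasses TensorIterator would pass the output sizes/options computed by its meta
// function, e.g.
//
//   check_inplace(self, computed_sizes, computed_options);
//
// and only write the result into `self` once these checks pass.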
namespace {
at::Tensor wrapper_ZeroTensor_Tensor_add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  // No device check
  // DeviceGuard omitted
  return at::native::add_zerotensor(self, other, alpha);
}
} // anonymous namespace
namespace {
at::Tensor wrapper_ZeroTensor__as_strided(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset) {
  // No device check
  // DeviceGuard omitted
  return at::native::as_strided_tensorimpl(self, C10_AS_INTARRAYREF_SLOW(size), C10_AS_INTARRAYREF_SLOW(stride), storage_offset.has_value() ? c10::make_optional(storage_offset->expect_int()) : c10::nullopt);
}
} // anonymous namespace
namespace {
at::Tensor wrapper_ZeroTensor_Tensor_div(const at::Tensor & self, const at::Tensor & other) {
  // No device check
  // DeviceGuard omitted
  return at::native::div_zerotensor(self, other);
}
} // anonymous namespace
namespace {
at::Tensor wrapper_ZeroTensor_Tensor_mul(const at::Tensor & self, const at::Tensor & other) {
  // No device check
  // DeviceGuard omitted
  return at::native::mul_zerotensor(self, other);
}
} // anonymous namespace
namespace {
at::Tensor wrapper_ZeroTensor___reshape_alias(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) {
  // No device check
  // DeviceGuard omitted
  return at::native::_reshape_alias(self, C10_AS_INTARRAYREF_SLOW(size), C10_AS_INTARRAYREF_SLOW(stride));
}
} // anonymous namespace
namespace {
at::Tensor wrapper_ZeroTensor_Tensor_sub(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  // No device check
  // DeviceGuard omitted
  return at::native::sub_zerotensor(self, other, alpha);
}
} // anonymous namespace
namespace {
at::Tensor wrapper_ZeroTensor__view(const at::Tensor & self, c10::SymIntArrayRef size) {
  // No device check
  // DeviceGuard omitted
  return at::native::view(self, C10_AS_INTARRAYREF_SLOW(size));
}
} // anonymous namespace
namespace {
at::Tensor wrapper_ZeroTensor__linalg_cross(const at::Tensor & self, const at::Tensor & other, int64_t dim) {
  // No device check
  // DeviceGuard omitted
  return at::native::linalg_cross_zerotensor(self, other, dim);
}
} // anonymous namespace
TORCH_LIBRARY_IMPL(aten, ZeroTensor, m) {
  m.impl("add.Tensor", TORCH_FN(wrapper_ZeroTensor_Tensor_add));
  m.impl("as_strided", TORCH_FN(wrapper_ZeroTensor__as_strided));
  m.impl("div.Tensor", TORCH_FN(wrapper_ZeroTensor_Tensor_div));
  m.impl("mul.Tensor", TORCH_FN(wrapper_ZeroTensor_Tensor_mul));
  m.impl("_reshape_alias", TORCH_FN(wrapper_ZeroTensor___reshape_alias));
  m.impl("sub.Tensor", TORCH_FN(wrapper_ZeroTensor_Tensor_sub));
  m.impl("view", TORCH_FN(wrapper_ZeroTensor__view));
  m.impl("linalg_cross", TORCH_FN(wrapper_ZeroTensor__linalg_cross));
};
} // anonymous namespace
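// Example (illustrative comment, not part of the generated registrations): once the
// kernels above are registered, any tensor carrying the ZeroTensor dispatch key is
// routed to them through the regular dispatcher, e.g.
//
//   at::Tensor z = at::_efficientzerotensor({2, 3});  // tensor tagged with the ZeroTensor key
//   at::Tensor y = at::mul(z, at::randn({2, 3}));     // dispatches to wrapper_ZeroTensor_Tensor_mul
//
// The creation API shown is an assumption of this sketch; only the registration
// mechanics above are guaranteed by this file.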
namespace zerotensor {
at::Tensor add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  return wrapper_ZeroTensor_Tensor_add(self, other, alpha);
}
at::Tensor as_strided(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset) {
  return wrapper_ZeroTensor__as_strided(self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? c10::make_optional(c10::SymInt(*storage_offset)) : c10::nullopt);
}
at::Tensor as_strided_symint(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional<c10::SymInt> storage_offset) {
  return wrapper_ZeroTensor__as_strided(self, size, stride, storage_offset);
}
at::Tensor div(const at::Tensor & self, const at::Tensor & other) {
  return wrapper_ZeroTensor_Tensor_div(self, other);
}
at::Tensor mul(const at::Tensor & self, const at::Tensor & other) {
  return wrapper_ZeroTensor_Tensor_mul(self, other);
}
at::Tensor _reshape_alias(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride) {
  return wrapper_ZeroTensor___reshape_alias(self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride));
}
at::Tensor _reshape_alias_symint(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) {
  return wrapper_ZeroTensor___reshape_alias(self, size, stride);
}
at::Tensor sub(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  return wrapper_ZeroTensor_Tensor_sub(self, other, alpha);
}
at::Tensor view(const at::Tensor & self, at::IntArrayRef size) {
  return wrapper_ZeroTensor__view(self, c10::fromIntArrayRefSlow(size));
}
at::Tensor view_symint(const at::Tensor & self, c10::SymIntArrayRef size) {
  return wrapper_ZeroTensor__view(self, size);
}
at::Tensor linalg_cross(const at::Tensor & self, const at::Tensor & other, int64_t dim) {
  return wrapper_ZeroTensor__linalg_cross(self, other, dim);
}
} // namespace zerotensor
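// Note (illustrative comment): the at::zerotensor:: functions above are thin,
// dispatch-free entry points into the same wrappers, e.g.
//
//   at::Tensor y = at::zerotensor::mul(z, x);  // calls wrapper_ZeroTensor_Tensor_mul directly
//
// assuming `z` and `x` are tensors for which the ZeroTensor kernels are valid.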
} // namespace at