1// required for old g++ to compile PRId64 macros, see
2// https://github.com/pytorch/pytorch/issues/3571
3// for context
4#ifndef __STDC_FORMAT_MACROS
5#define __STDC_FORMAT_MACROS
6#endif
7
// an external backend might generate files within its code tree
// and check all the source files within that tree with clang-format.
// so, disable it here, since the backend might use a different config.
11// clang-format off
12
13// NOTE: This condition is true for all PyTorch internal libraries, it
14// just excludes external projects such as torch_xla which
15// re-use some of the PyTorch codegen machinery.
16#if defined(CAFFE2_BUILD_MAIN_LIB) || \
17 defined(TORCH_CUDA_BUILD_MAIN_LIB) || \
18 defined(TORCH_HIP_BUILD_MAIN_LIB) || \
19 defined(TORCH_CUDA_CU_BUILD_MAIN_LIB) || \
20 defined(TORCH_CUDA_CPP_BUILD_MAIN_LIB)
21#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
22#endif
23
24// @generated by torchgen/gen.py from RegisterDispatchKey.cpp
25
26#include <c10/core/TensorImpl.h>
27#include <c10/core/Allocator.h>
28#include <ATen/DeviceGuard.h>
29#include <ATen/NamedTensorUtils.h>
30#include <ATen/Utils.h>
31#include <ATen/WrapDimUtils.h>
32#include <ATen/Dispatch.h>
33#include <c10/util/ExclusivelyOwned.h>
34#include <c10/util/Half.h>
35#include <c10/core/UndefinedTensorImpl.h>
36#include <c10/util/Optional.h>
37#include <ATen/Tensor.h>
38#include <ATen/native/Resize.h>
39
40#include <cstddef>
41#include <functional>
42#include <memory>
43#include <utility>
44
45#include <ATen/Config.h>
46#include <ATen/core/op_registration/adaption.h>
47#include <torch/library.h>
48
49
50#include <ATen/ops/as_strided_native.h>
51#include <ATen/ops/empty.h>
52#include <ATen/ops/empty_strided.h>
53#include <ATen/ops/_copy_from_and_resize.h>
54#include <ATen/ops/_copy_from.h>
55#include <ATen/ops/empty_native.h>
56
57// See template file RegisterDispatchDefinitions.ini
58namespace at {
59// NB: TORCH_LIBRARY_IMPL must be in an anonymous namespace to avoid
60// ambiguity with conflicting identifiers that may have been defined in
61// at namespace already.
62namespace {
63void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
64 TORCH_CHECK(options.dtype() == out.dtype(),
65 "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
66 TORCH_CHECK(options.device() == out.device(),
67 "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
68 const bool resized = at::native::resize_output(out, sizes);
69 // Only restride if a resize occurred; otherwise we ignore the (advisory)
70 // strides from the meta function and directly use the output tensor's
71 // preexisting strides
72 if (resized) {
73 if (!strides.empty()) {
74 TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
75 // TODO: avoid the redispatch here
76 out.as_strided_(sizes, strides);
77 } else if (options.memory_format_opt().has_value()) {
78 out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
79 }
80 }
81}
82void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
83 // These checks are needed on those operators that:
84 // 1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm')
85 // 2) have particular typing rules (e.g. 'cumsum' and 'cumprod')
86 // For other operators (e.g. 'add'), 'TensorIterator' already checks
87 // these things separately.
88 TORCH_CHECK(options.dtype() == self.dtype(),
89 "Bad in-place call: ",
90 "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
91 TORCH_CHECK(options.device() == self.device(),
92 "Bad in-place call: ",
93 "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
94 TORCH_CHECK(sizes == self.sizes(),
95 "Bad in-place call: ",
96 "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
97}
namespace {
// Dispatch-key wrapper for aten::empty.memory_format under the QuantizedMeta
// key: forwards directly to at::native::empty_unknown_quantized after
// converting the SymInt sizes to a plain IntArrayRef. The codegen emitted
// neither a device check nor a DeviceGuard for this kernel (see the comments
// below, preserved from the generator).
at::Tensor wrapper_QuantizedMeta_memory_format_empty(c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
    // No device check
  // DeviceGuard omitted
  return at::native::empty_unknown_quantized(C10_AS_INTARRAYREF_SLOW(size), dtype, layout, device, pin_memory, memory_format);
}
} // anonymous namespace
// Registers the wrapper above with the PyTorch dispatcher as the
// implementation of "empty.memory_format" (in the `aten` library) for the
// QuantizedMeta dispatch key.
TORCH_LIBRARY_IMPL(aten, QuantizedMeta, m) {
    m.impl("empty.memory_format",
TORCH_FN(wrapper_QuantizedMeta_memory_format_empty));
};
109} // anonymous namespace
110namespace quantizedmeta {
111at::Tensor empty(at::IntArrayRef size, at::TensorOptions options, c10::optional<at::MemoryFormat> memory_format) {
112return wrapper_QuantizedMeta_memory_format_empty(c10::fromIntArrayRefSlow(size), optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
113}
114at::Tensor empty(at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
115return wrapper_QuantizedMeta_memory_format_empty(c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, memory_format);
116}
117at::Tensor empty_symint(c10::SymIntArrayRef size, at::TensorOptions options, c10::optional<at::MemoryFormat> memory_format) {
118return wrapper_QuantizedMeta_memory_format_empty(size, optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format));
119}
120at::Tensor empty_symint(c10::SymIntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format) {
121return wrapper_QuantizedMeta_memory_format_empty(size, dtype, layout, device, pin_memory, memory_format);
122}
123} // namespace quantizedmeta
124} // namespace at
125