1#pragma once
2
3#include <ATen/DimVector.h>
4#include <ATen/core/Dimname.h>
5#include <c10/core/TensorOptions.h>
6#include <c10/util/strides.h>
7
8C10_CLANG_DIAGNOSTIC_PUSH()
9#if C10_CLANG_HAS_WARNING("-Wdeprecated-copy-dtor")
10C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy-dtor")
11#endif
12
13namespace at {
14
15class Tensor;
16
17namespace impl {
18
// Use this to define the prototype for a meta function. There are two
// versions; one that takes one argument (just the operator name), or FUNC2
// variant that takes two arguments (operator name and overload name).
//
// The macro expands to the definition header of structured_<name>::meta
// (or structured_<name>_<overload>::meta), which returns void; the
// parameter list and body follow the macro invocation at the call site.
//
// Example usage:
//
// TORCH_META_FUNC2(add, Tensor) (
//   const Tensor& self, const Tensor& other
// ) {
//   ... compute sizes and options ...
//   set_output(sizes, options);
// }
//
#define TORCH_META_FUNC(name) void structured_##name::meta
#define TORCH_META_FUNC2(name, overload) \
  void structured_##name##_##overload::meta
35
// These are versions of TORCH_META_FUNC(2) that include a precompute_out struct
// as a return value. They should be used when the kernel in question has
// precomputed values declared in native_functions.yaml and the corresponding
// implementation should return an instance of the aforementioned struct.
//
// Unlike TORCH_META_FUNC(2), the meta function's return type here is the
// codegen-declared structured_<name>::meta_return_ty rather than void.
#define TORCH_PRECOMPUTE_META_FUNC(name) \
  structured_##name::meta_return_ty structured_##name::meta
#define TORCH_PRECOMPUTE_META_FUNC2(name, overload) \
  structured_##name##_##overload::meta_return_ty \
  structured_##name##_##overload::meta
45
// Use this to create a precompute struct in a meta function.
// Expands to the (empty-argument) precompute_out template of the structured
// kernel class, whose set_* builder methods are then chained to fill in the
// precomputed fields before returning it from the meta function.
#define TORCH_PRECOMPUTE_STRUCT(name) structured_##name::precompute_out<>
#define TORCH_PRECOMPUTE_STRUCT2(name, overload) \
  structured_##name##_##overload::precompute_out<>
50
// Use this to define the prototype for an implementation. This takes only
// one argument, which is the name of the dispatch key entry you're
// implementing.
//
// The macro expands to the definition header of structured_<name>::impl,
// which returns void; the parameter list and body follow at the call site.
//
// Example usage:
//
// TORCH_IMPL_FUNC(add_cpu) (
//   Tensor& result, const Tensor& self, const Tensor& other
// ) {
//   ... do the actual implementation ...
// }
//
#define TORCH_IMPL_FUNC(name) void structured_##name::impl
64
65// Base class for all structured kernel classes. The set_output virtual
66// method is varied depending whether or not the operator is
67// functional/out/inplace, and could also be specialized for CPU/CUDA/etc
68// (although presently it isn't).
69//
70// A notable subclass of this interface is TensorIteratorBase.
71struct TORCH_API MetaBase {
72 virtual const Tensor& maybe_get_output(int64_t output_idx) = 0;
73
74 // Note: [set_output_*]
75 // See: https://github.com/pytorch/pytorch/issues/69813
76 // Whenever defining the output properties in the META function of a
77 // structured kernel (what was usually done with `set_output`), use one of
78 // these 3 variants, instead. In order to decide which variant to use, check
79 // the following decision tree:
80 //
81 // - Can the kernel you are going to implement support output tensors
82 // with arbitrary strides?
83 // |
84 // -- YES: `set_output_raw_strided`
85 // |
86 // -- NO: Should the output tensor strides be contiguous?
87 // |
88 // -- YES: `set_output_contiguous`
89 // |
90 // -- NO: `set_output_strided`
91 //
92 // Use this function whenever the kernel requires specific strides for the
93 // output. If `strides` does not match the given output strides, proxy outputs
94 // will be created and passed to the IMPL function.
95 virtual void set_output_strided(
96 int64_t output_idx,
97 IntArrayRef sizes,
98 IntArrayRef strides,
99 TensorOptions options,
100 DimnameList names = {}) {
101 TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented.");
102 }
103
104 // Use this function whenever the kernel knows how to handle arbitrary strided
105 // outputs. This function has the same behavior as the old `set_output`: it
106 // will only re-stride if the given output was resized.
107 virtual void set_output_raw_strided(
108 int64_t output_idx,
109 IntArrayRef sizes,
110 IntArrayRef strides_hint,
111 TensorOptions options,
112 DimnameList names = {}) {
113 TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented.");
114 }
115
116 // Use this function if the kernel requires contiguous strides.
117 // Alias for `set_output_strided`, but with contiguous strides.
118 void set_output_contiguous(
119 int64_t output_idx,
120 IntArrayRef sizes,
121 TensorOptions options,
122 DimnameList names = {}) {
123 auto strides = c10::contiguous_strides(sizes);
124 set_output_strided(output_idx, sizes, strides, options, names);
125 }
126
127 // Returns a reference to an undefined tensor if there is no presupplied
128 // output
129 const Tensor& maybe_get_output() {
130 return maybe_get_output(0);
131 }
132 virtual ~MetaBase() = default;
133};
134
135} // namespace impl
136
137} // namespace at
138
139C10_CLANG_DIAGNOSTIC_POP()
140