1#pragma once
2
3#include <string>
4
5#include <ATen/core/Reduction.h>
6#include <c10/util/Exception.h>
7#include <c10/util/variant.h>
8#include <torch/csrc/Export.h>
9
// Declares an enum "tag" called `name`, which consists of:
//   * an empty struct `torch::enumtype::k<name>`, used as a `c10::variant`
//     alternative in option types, and
//   * a `TORCH_API` constant `torch::k<name>` of that struct type, whose
//     definition is emitted by TORCH_ENUM_DEFINE(name).
#define TORCH_ENUM_DECLARE(name)                                          \
  namespace torch {                                                       \
  namespace enumtype {                                                    \
  /*                                                                      \
    The default constructor is user-provided on purpose. Without it,     \
    Clang 3.8 rejects the default-initialized `const` object produced by  \
    TORCH_ENUM_DEFINE with:                                               \
    ```                                                                   \
    error: default initialization of an object of const type 'const       \
    enumtype::Enum1' without a user-provided default constructor          \
    ```                                                                   \
  */                                                                      \
  struct k##name {                                                        \
    k##name() {}                                                          \
  };                                                                      \
  } /* namespace enumtype */                                              \
  TORCH_API extern const enumtype::k##name k##name;                       \
  } /* namespace torch */
27
// Emits the definition of the constant `torch::k<name>` declared by
// TORCH_ENUM_DECLARE(name).
//
// Expand it after TORCH_ENUM_DECLARE(name) in the same translation unit.
// Without the preceding `extern` declaration, a namespace-scope `const` object
// has internal linkage and would not define the exported symbol.
#define TORCH_ENUM_DEFINE(name)           \
  namespace torch {                       \
  const enumtype::k##name k##name{};      \
  } /* namespace torch */
32
// Expands, inside a visitor struct, to a call operator that returns the
// spelling "k<name>" for the tag type `enumtype::k<name>`.
//
// The parameter is unnamed because only its type selects the overload. A
// named but unused parameter triggers -Wunused-parameter wherever this header
// is expanded. The result is built from one string literal ("k" #name) rather
// than by creating a temporary std::string and appending to it.
#define TORCH_ENUM_PRETTY_PRINT(name)                        \
  std::string operator()(const enumtype::k##name&) const {   \
    return std::string("k" #name);                           \
  }
38
39// NOTE: Backstory on why we need the following two macros:
40//
41// Consider the following options class:
42//
43// ```
// struct TORCH_API SomeOptions {
//   typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
//       reduction_t;
//   SomeOptions(reduction_t reduction = torch::kMean) : reduction_(reduction) {}
//
//   TORCH_ARG(reduction_t, reduction);
// };
51// ```
52//
53// and the functional that uses it:
54//
55// ```
56// Tensor some_functional(
57// const Tensor& input,
58// SomeOptions options = {}) {
59// ...
60// }
61// ```
62//
63// Normally, we would expect this to work:
64//
65// `F::some_functional(input, torch::kNone)`
66//
67// However, it throws the following error instead:
68//
69// ```
70// error: could not convert `torch::kNone` from `const torch::enumtype::kNone`
71// to `torch::nn::SomeOptions`
72// ```
73//
74// To get around this problem, we explicitly provide the following constructors
75// for `SomeOptions`:
76//
77// ```
78// SomeOptions(torch::enumtype::kNone reduction) : reduction_(torch::kNone) {}
79// SomeOptions(torch::enumtype::kMean reduction) : reduction_(torch::kMean) {}
80// SomeOptions(torch::enumtype::kSum reduction) : reduction_(torch::kSum) {}
81// ```
82//
83// so that the conversion from `torch::kNone` to `SomeOptions` would work.
84//
// Note that the macros also generate a defaulted default constructor
// (`SomeOptions() = default;`), so that `SomeOptions options = {}` can work.
// Expands, inside the body of the options struct `OPTIONS_NAME`, to:
//   * a defaulted default constructor, and
//   * one converting constructor per tag type `torch::enumtype::TYPEn`
//     (n = 1..3), each of which initializes the member `ARG_NAME##_` with the
//     constant `torch::TYPEn`.
// The converting constructors are deliberately not `explicit`, because they
// are what lets a bare tag such as `torch::kNone` convert to `OPTIONS_NAME`.
//
// The tag parameter is left unnamed because only its type matters. Naming it
// (previously `ARG_NAME`) produced an unused parameter, which warns under
// -Wunused-parameter (-Wextra) in every options struct that expands this macro.
#define TORCH_OPTIONS_CTOR_VARIANT_ARG3(                               \
    OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2, TYPE3)                       \
  OPTIONS_NAME() = default;                                            \
  OPTIONS_NAME(torch::enumtype::TYPE1) : ARG_NAME##_(torch::TYPE1) {}  \
  OPTIONS_NAME(torch::enumtype::TYPE2) : ARG_NAME##_(torch::TYPE2) {}  \
  OPTIONS_NAME(torch::enumtype::TYPE3) : ARG_NAME##_(torch::TYPE3) {}
93
// Four-tag variant of the constructor generator. Expands, inside the body of
// the options struct `OPTIONS_NAME`, to:
//   * a defaulted default constructor, and
//   * one converting constructor per tag type `torch::enumtype::TYPEn`
//     (n = 1..4), each of which initializes the member `ARG_NAME##_` with the
//     constant `torch::TYPEn`.
// The converting constructors are deliberately not `explicit`, because they
// are what lets a bare tag such as `torch::kNone` convert to `OPTIONS_NAME`.
//
// The tag parameter is left unnamed because only its type matters. Naming it
// (previously `ARG_NAME`) produced an unused parameter, which warns under
// -Wunused-parameter (-Wextra) in every options struct that expands this macro.
#define TORCH_OPTIONS_CTOR_VARIANT_ARG4(                               \
    OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2, TYPE3, TYPE4)                \
  OPTIONS_NAME() = default;                                            \
  OPTIONS_NAME(torch::enumtype::TYPE1) : ARG_NAME##_(torch::TYPE1) {}  \
  OPTIONS_NAME(torch::enumtype::TYPE2) : ARG_NAME##_(torch::TYPE2) {}  \
  OPTIONS_NAME(torch::enumtype::TYPE3) : ARG_NAME##_(torch::TYPE3) {}  \
  OPTIONS_NAME(torch::enumtype::TYPE4) : ARG_NAME##_(torch::TYPE4) {}
101
102TORCH_ENUM_DECLARE(Linear)
103TORCH_ENUM_DECLARE(Conv1D)
104TORCH_ENUM_DECLARE(Conv2D)
105TORCH_ENUM_DECLARE(Conv3D)
106TORCH_ENUM_DECLARE(ConvTranspose1D)
107TORCH_ENUM_DECLARE(ConvTranspose2D)
108TORCH_ENUM_DECLARE(ConvTranspose3D)
109TORCH_ENUM_DECLARE(Sigmoid)
110TORCH_ENUM_DECLARE(Tanh)
111TORCH_ENUM_DECLARE(ReLU)
112TORCH_ENUM_DECLARE(GELU)
113TORCH_ENUM_DECLARE(SiLU)
114TORCH_ENUM_DECLARE(Mish)
115TORCH_ENUM_DECLARE(LeakyReLU)
116TORCH_ENUM_DECLARE(FanIn)
117TORCH_ENUM_DECLARE(FanOut)
118TORCH_ENUM_DECLARE(Constant)
119TORCH_ENUM_DECLARE(Reflect)
120TORCH_ENUM_DECLARE(Replicate)
121TORCH_ENUM_DECLARE(Circular)
122TORCH_ENUM_DECLARE(Nearest)
123TORCH_ENUM_DECLARE(Bilinear)
124TORCH_ENUM_DECLARE(Bicubic)
125TORCH_ENUM_DECLARE(Trilinear)
126TORCH_ENUM_DECLARE(Area)
127TORCH_ENUM_DECLARE(NearestExact)
128TORCH_ENUM_DECLARE(Sum)
129TORCH_ENUM_DECLARE(Mean)
130TORCH_ENUM_DECLARE(Max)
131TORCH_ENUM_DECLARE(None)
132TORCH_ENUM_DECLARE(BatchMean)
133TORCH_ENUM_DECLARE(Zeros)
134TORCH_ENUM_DECLARE(Border)
135TORCH_ENUM_DECLARE(Reflection)
136TORCH_ENUM_DECLARE(RNN_TANH)
137TORCH_ENUM_DECLARE(RNN_RELU)
138TORCH_ENUM_DECLARE(LSTM)
139TORCH_ENUM_DECLARE(GRU)
140TORCH_ENUM_DECLARE(Valid)
141TORCH_ENUM_DECLARE(Same)
142
143namespace torch {
144namespace enumtype {
145
146struct _compute_enum_name {
147 TORCH_ENUM_PRETTY_PRINT(Linear)
148 TORCH_ENUM_PRETTY_PRINT(Conv1D)
149 TORCH_ENUM_PRETTY_PRINT(Conv2D)
150 TORCH_ENUM_PRETTY_PRINT(Conv3D)
151 TORCH_ENUM_PRETTY_PRINT(ConvTranspose1D)
152 TORCH_ENUM_PRETTY_PRINT(ConvTranspose2D)
153 TORCH_ENUM_PRETTY_PRINT(ConvTranspose3D)
154 TORCH_ENUM_PRETTY_PRINT(Sigmoid)
155 TORCH_ENUM_PRETTY_PRINT(Tanh)
156 TORCH_ENUM_PRETTY_PRINT(ReLU)
157 TORCH_ENUM_PRETTY_PRINT(GELU)
158 TORCH_ENUM_PRETTY_PRINT(SiLU)
159 TORCH_ENUM_PRETTY_PRINT(Mish)
160 TORCH_ENUM_PRETTY_PRINT(LeakyReLU)
161 TORCH_ENUM_PRETTY_PRINT(FanIn)
162 TORCH_ENUM_PRETTY_PRINT(FanOut)
163 TORCH_ENUM_PRETTY_PRINT(Constant)
164 TORCH_ENUM_PRETTY_PRINT(Reflect)
165 TORCH_ENUM_PRETTY_PRINT(Replicate)
166 TORCH_ENUM_PRETTY_PRINT(Circular)
167 TORCH_ENUM_PRETTY_PRINT(Nearest)
168 TORCH_ENUM_PRETTY_PRINT(Bilinear)
169 TORCH_ENUM_PRETTY_PRINT(Bicubic)
170 TORCH_ENUM_PRETTY_PRINT(Trilinear)
171 TORCH_ENUM_PRETTY_PRINT(Area)
172 TORCH_ENUM_PRETTY_PRINT(NearestExact)
173 TORCH_ENUM_PRETTY_PRINT(Sum)
174 TORCH_ENUM_PRETTY_PRINT(Mean)
175 TORCH_ENUM_PRETTY_PRINT(Max)
176 TORCH_ENUM_PRETTY_PRINT(None)
177 TORCH_ENUM_PRETTY_PRINT(BatchMean)
178 TORCH_ENUM_PRETTY_PRINT(Zeros)
179 TORCH_ENUM_PRETTY_PRINT(Border)
180 TORCH_ENUM_PRETTY_PRINT(Reflection)
181 TORCH_ENUM_PRETTY_PRINT(RNN_TANH)
182 TORCH_ENUM_PRETTY_PRINT(RNN_RELU)
183 TORCH_ENUM_PRETTY_PRINT(LSTM)
184 TORCH_ENUM_PRETTY_PRINT(GRU)
185 TORCH_ENUM_PRETTY_PRINT(Valid)
186 TORCH_ENUM_PRETTY_PRINT(Same)
187};
188
189template <typename V>
190std::string get_enum_name(V variant_enum) {
191 return c10::visit(enumtype::_compute_enum_name{}, variant_enum);
192}
193
194template <typename V>
195at::Reduction::Reduction reduction_get_enum(V variant_enum) {
196 if (c10::get_if<enumtype::kNone>(&variant_enum)) {
197 return at::Reduction::None;
198 } else if (c10::get_if<enumtype::kMean>(&variant_enum)) {
199 return at::Reduction::Mean;
200 } else if (c10::get_if<enumtype::kSum>(&variant_enum)) {
201 return at::Reduction::Sum;
202 } else {
203 TORCH_CHECK(
204 false,
205 get_enum_name(variant_enum),
206 " is not a valid value for reduction");
207 return at::Reduction::END;
208 }
209}
210
211} // namespace enumtype
212} // namespace torch
213