1 | #pragma once |
2 | |
3 | #include <string> |
4 | |
5 | #include <ATen/core/Reduction.h> |
6 | #include <c10/util/Exception.h> |
7 | #include <c10/util/variant.h> |
8 | #include <torch/csrc/Export.h> |
9 | |
// Declares the empty tag type `torch::enumtype::k<name>` and the exported
// global constant `torch::k<name>` of that type. The constant is only
// declared (`extern`) here; its definition comes from TORCH_ENUM_DEFINE(name).
//
// NOTE: each tag struct deliberately has a user-provided default constructor.
// Without one, Clang 3.8 rejects the `const` global with:
//   error: default initialization of an object of const type 'const
//   enumtype::Enum1' without a user-provided default constructor
// Do not replace `k##name() {}` with `= default`: that is not user-provided.
#define TORCH_ENUM_DECLARE(name)                    \
  namespace torch {                                 \
  namespace enumtype {                              \
  struct k##name {                                  \
    k##name() {}                                    \
  };                                                \
  } /* namespace enumtype */                        \
  TORCH_API extern const enumtype::k##name k##name; \
  } /* namespace torch */
27 | |
// Emits the out-of-line definition of the global `torch::k<name>` constant
// that TORCH_ENUM_DECLARE(name) declares `extern`. Because that declaration
// has external linkage, this macro must be expanded exactly once per enum
// name across the whole program.
#define TORCH_ENUM_DEFINE(name)      \
  namespace torch {                  \
  const enumtype::k##name k##name;   \
  } /* namespace torch */
32 | |
// Expands to a `const` call operator that returns the printable name of the
// tag type `enumtype::k<name>`, i.e. the string "k<name>". For example,
// TORCH_ENUM_PRETTY_PRINT(Mean) produces an overload returning "kMean".
// Expanding it once per tag inside a struct yields a visitor usable with
// `c10::visit`.
//
// The argument only selects the overload and is never read, so it is left
// unnamed; a named-but-unused parameter triggers -Wunused-parameter at every
// expansion. The name is built from adjacent string literals ("k" followed by
// the stringized token), so no intermediate std::string is concatenated.
#define TORCH_ENUM_PRETTY_PRINT(name)                                 \
  std::string operator()(const enumtype::k##name& /*unused*/) const { \
    return std::string("k" #name);                                    \
  }
38 | |
39 | // NOTE: Backstory on why we need the following two macros: |
40 | // |
41 | // Consider the following options class: |
42 | // |
43 | // ``` |
44 | // struct TORCH_API SomeOptions { |
//   typedef c10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
//       reduction_t;
//   SomeOptions(reduction_t reduction = torch::kMean)
//       : reduction_(reduction) {}
48 | // |
49 | // TORCH_ARG(reduction_t, reduction); |
50 | // }; |
51 | // ``` |
52 | // |
53 | // and the functional that uses it: |
54 | // |
55 | // ``` |
56 | // Tensor some_functional( |
57 | // const Tensor& input, |
58 | // SomeOptions options = {}) { |
59 | // ... |
60 | // } |
61 | // ``` |
62 | // |
63 | // Normally, we would expect this to work: |
64 | // |
65 | // `F::some_functional(input, torch::kNone)` |
66 | // |
67 | // However, it throws the following error instead: |
68 | // |
69 | // ``` |
70 | // error: could not convert `torch::kNone` from `const torch::enumtype::kNone` |
71 | // to `torch::nn::SomeOptions` |
72 | // ``` |
73 | // |
74 | // To get around this problem, we explicitly provide the following constructors |
75 | // for `SomeOptions`: |
76 | // |
77 | // ``` |
78 | // SomeOptions(torch::enumtype::kNone reduction) : reduction_(torch::kNone) {} |
79 | // SomeOptions(torch::enumtype::kMean reduction) : reduction_(torch::kMean) {} |
80 | // SomeOptions(torch::enumtype::kSum reduction) : reduction_(torch::kSum) {} |
81 | // ``` |
82 | // |
83 | // so that the conversion from `torch::kNone` to `SomeOptions` would work. |
84 | // |
// Note that we also provide a default constructor (`SomeOptions() = default;`),
// so that `SomeOptions options = {}` still works.
// Declares the constructors an options struct needs so that a bare enum
// constant (e.g. `torch::kNone`) converts implicitly to the options struct:
//   - a defaulted default constructor, so `OPTIONS_NAME options = {}` works;
//   - one constructor per tag type TYPE1..TYPE3, each storing that
//     alternative in the member `ARG_NAME##_`.
// The converting constructors are intentionally not `explicit`: implicit
// conversion from the enum constant is the whole point.
//
// Each constructor initializes the member from its own argument. The argument
// has the same empty tag type as the global `torch::TYPEn` constant, so the
// stored alternative is unchanged. Using the argument keeps the parameter
// referenced (no -Wunused-parameter) and avoids reading an `extern` global,
// possibly defined in another translation unit, while constructing options.
#define TORCH_OPTIONS_CTOR_VARIANT_ARG3(             \
    OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2, TYPE3)     \
  OPTIONS_NAME() = default;                          \
  OPTIONS_NAME(torch::enumtype::TYPE1 ARG_NAME)      \
      : ARG_NAME##_(ARG_NAME) {}                     \
  OPTIONS_NAME(torch::enumtype::TYPE2 ARG_NAME)      \
      : ARG_NAME##_(ARG_NAME) {}                     \
  OPTIONS_NAME(torch::enumtype::TYPE3 ARG_NAME)      \
      : ARG_NAME##_(ARG_NAME) {}
93 | |
// Four-alternative counterpart of TORCH_OPTIONS_CTOR_VARIANT_ARG3. It declares
// a defaulted default constructor (so `OPTIONS_NAME options = {}` works) plus
// one intentionally non-`explicit` constructor per tag type TYPE1..TYPE4, so
// a bare enum constant converts implicitly to the options struct.
//
// Each constructor initializes the member `ARG_NAME##_` from its own argument.
// The argument has the same empty tag type as the global `torch::TYPEn`
// constant, so the stored alternative is unchanged. Using the argument keeps
// the parameter referenced (no -Wunused-parameter) and avoids reading an
// `extern` global, possibly defined in another translation unit, while
// constructing options.
#define TORCH_OPTIONS_CTOR_VARIANT_ARG4(                 \
    OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2, TYPE3, TYPE4)  \
  OPTIONS_NAME() = default;                              \
  OPTIONS_NAME(torch::enumtype::TYPE1 ARG_NAME)          \
      : ARG_NAME##_(ARG_NAME) {}                         \
  OPTIONS_NAME(torch::enumtype::TYPE2 ARG_NAME)          \
      : ARG_NAME##_(ARG_NAME) {}                         \
  OPTIONS_NAME(torch::enumtype::TYPE3 ARG_NAME)          \
      : ARG_NAME##_(ARG_NAME) {}                         \
  OPTIONS_NAME(torch::enumtype::TYPE4 ARG_NAME)          \
      : ARG_NAME##_(ARG_NAME) {}
101 | |
// Declare every enum tag type `torch::enumtype::k<Name>` and its global
// constant `torch::k<Name>`.
// Keep this list in sync, and in the same order, with the
// TORCH_ENUM_PRETTY_PRINT overloads in `torch::enumtype::_compute_enum_name`
// below. A tag without a matching overload cannot appear in a variant passed
// to `get_enum_name`, because `c10::visit` needs an overload for every
// alternative.
// Each name declared here also needs exactly one TORCH_ENUM_DEFINE expansion
// somewhere in the program to provide the constant's definition.
TORCH_ENUM_DECLARE(Linear)
TORCH_ENUM_DECLARE(Conv1D)
TORCH_ENUM_DECLARE(Conv2D)
TORCH_ENUM_DECLARE(Conv3D)
TORCH_ENUM_DECLARE(ConvTranspose1D)
TORCH_ENUM_DECLARE(ConvTranspose2D)
TORCH_ENUM_DECLARE(ConvTranspose3D)
TORCH_ENUM_DECLARE(Sigmoid)
TORCH_ENUM_DECLARE(Tanh)
TORCH_ENUM_DECLARE(ReLU)
TORCH_ENUM_DECLARE(GELU)
TORCH_ENUM_DECLARE(SiLU)
TORCH_ENUM_DECLARE(Mish)
TORCH_ENUM_DECLARE(LeakyReLU)
TORCH_ENUM_DECLARE(FanIn)
TORCH_ENUM_DECLARE(FanOut)
TORCH_ENUM_DECLARE(Constant)
TORCH_ENUM_DECLARE(Reflect)
TORCH_ENUM_DECLARE(Replicate)
TORCH_ENUM_DECLARE(Circular)
TORCH_ENUM_DECLARE(Nearest)
TORCH_ENUM_DECLARE(Bilinear)
TORCH_ENUM_DECLARE(Bicubic)
TORCH_ENUM_DECLARE(Trilinear)
TORCH_ENUM_DECLARE(Area)
TORCH_ENUM_DECLARE(NearestExact)
TORCH_ENUM_DECLARE(Sum)
TORCH_ENUM_DECLARE(Mean)
TORCH_ENUM_DECLARE(Max)
TORCH_ENUM_DECLARE(None)
TORCH_ENUM_DECLARE(BatchMean)
TORCH_ENUM_DECLARE(Zeros)
TORCH_ENUM_DECLARE(Border)
TORCH_ENUM_DECLARE(Reflection)
TORCH_ENUM_DECLARE(RNN_TANH)
TORCH_ENUM_DECLARE(RNN_RELU)
TORCH_ENUM_DECLARE(LSTM)
TORCH_ENUM_DECLARE(GRU)
TORCH_ENUM_DECLARE(Valid)
TORCH_ENUM_DECLARE(Same)
142 | |
143 | namespace torch { |
144 | namespace enumtype { |
145 | |
146 | struct _compute_enum_name { |
147 | TORCH_ENUM_PRETTY_PRINT(Linear) |
148 | TORCH_ENUM_PRETTY_PRINT(Conv1D) |
149 | TORCH_ENUM_PRETTY_PRINT(Conv2D) |
150 | TORCH_ENUM_PRETTY_PRINT(Conv3D) |
151 | TORCH_ENUM_PRETTY_PRINT(ConvTranspose1D) |
152 | TORCH_ENUM_PRETTY_PRINT(ConvTranspose2D) |
153 | TORCH_ENUM_PRETTY_PRINT(ConvTranspose3D) |
154 | TORCH_ENUM_PRETTY_PRINT(Sigmoid) |
155 | TORCH_ENUM_PRETTY_PRINT(Tanh) |
156 | TORCH_ENUM_PRETTY_PRINT(ReLU) |
157 | TORCH_ENUM_PRETTY_PRINT(GELU) |
158 | TORCH_ENUM_PRETTY_PRINT(SiLU) |
159 | TORCH_ENUM_PRETTY_PRINT(Mish) |
160 | TORCH_ENUM_PRETTY_PRINT(LeakyReLU) |
161 | TORCH_ENUM_PRETTY_PRINT(FanIn) |
162 | TORCH_ENUM_PRETTY_PRINT(FanOut) |
163 | TORCH_ENUM_PRETTY_PRINT(Constant) |
164 | TORCH_ENUM_PRETTY_PRINT(Reflect) |
165 | TORCH_ENUM_PRETTY_PRINT(Replicate) |
166 | TORCH_ENUM_PRETTY_PRINT(Circular) |
167 | TORCH_ENUM_PRETTY_PRINT(Nearest) |
168 | TORCH_ENUM_PRETTY_PRINT(Bilinear) |
169 | TORCH_ENUM_PRETTY_PRINT(Bicubic) |
170 | TORCH_ENUM_PRETTY_PRINT(Trilinear) |
171 | TORCH_ENUM_PRETTY_PRINT(Area) |
172 | TORCH_ENUM_PRETTY_PRINT(NearestExact) |
173 | TORCH_ENUM_PRETTY_PRINT(Sum) |
174 | TORCH_ENUM_PRETTY_PRINT(Mean) |
175 | TORCH_ENUM_PRETTY_PRINT(Max) |
176 | TORCH_ENUM_PRETTY_PRINT(None) |
177 | TORCH_ENUM_PRETTY_PRINT(BatchMean) |
178 | TORCH_ENUM_PRETTY_PRINT(Zeros) |
179 | TORCH_ENUM_PRETTY_PRINT(Border) |
180 | TORCH_ENUM_PRETTY_PRINT(Reflection) |
181 | TORCH_ENUM_PRETTY_PRINT(RNN_TANH) |
182 | TORCH_ENUM_PRETTY_PRINT(RNN_RELU) |
183 | TORCH_ENUM_PRETTY_PRINT(LSTM) |
184 | TORCH_ENUM_PRETTY_PRINT(GRU) |
185 | TORCH_ENUM_PRETTY_PRINT(Valid) |
186 | TORCH_ENUM_PRETTY_PRINT(Same) |
187 | }; |
188 | |
189 | template <typename V> |
190 | std::string get_enum_name(V variant_enum) { |
191 | return c10::visit(enumtype::_compute_enum_name{}, variant_enum); |
192 | } |
193 | |
194 | template <typename V> |
195 | at::Reduction::Reduction reduction_get_enum(V variant_enum) { |
196 | if (c10::get_if<enumtype::kNone>(&variant_enum)) { |
197 | return at::Reduction::None; |
198 | } else if (c10::get_if<enumtype::kMean>(&variant_enum)) { |
199 | return at::Reduction::Mean; |
200 | } else if (c10::get_if<enumtype::kSum>(&variant_enum)) { |
201 | return at::Reduction::Sum; |
202 | } else { |
203 | TORCH_CHECK( |
204 | false, |
205 | get_enum_name(variant_enum), |
206 | " is not a valid value for reduction" ); |
207 | return at::Reduction::END; |
208 | } |
209 | } |
210 | |
211 | } // namespace enumtype |
212 | } // namespace torch |
213 | |