1 | #pragma once |
2 | |
3 | #include <ATen/core/boxing/OperatorKernel.h> |
4 | #include <ATen/core/ivalue.h> |
5 | #include <ATen/core/stack.h> |
6 | #include <c10/util/TypeList.h> |
7 | #include <ATen/core/IListRef.h> |
8 | #include <c10/util/intrusive_ptr.h> |
9 | #include <c10/util/Metaprogramming.h> |
10 | |
11 | #include <utility> |
12 | |
13 | namespace c10 { |
14 | |
// Re-export torch::jit::Stack under the c10 namespace; the boxed wrappers in this
// header take it as `Stack*`.
using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack to the c10 namespace.
// Forward declaration only: this header refers to OperatorHandle solely by const reference.
class OperatorHandle;
17 | |
18 | /* |
19 | * [Note: Argument forwarding in the dispatcher] |
20 | * |
21 | * The dispatcher uses a somewhat unusual way to forward arguments through several layers of |
22 | * wrapper functions. This can be confusing because an experienced C++ programmer would look at this |
23 | * and think "oh this is supposed to be forwarding a universal reference but the && is missing. This is a bug.". |
24 | * It is not a bug. The common way in C++ to forward arguments is to use universal references: |
25 | * |
26 | * > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); } |
27 | * |
28 | * but that relies on inferring the correct reference type (i.e. value vs & vs &&) from the argument. |
29 | * In our case, we cannot rely on the argument as supplied by the caller, because that could infer a |
30 | * different reference type than was used in the kernel function. The correct reference type |
31 | * is dictated by the kernel signature and must be identical since we cast function pointers |
32 | * through void* pointers and mismatches would be UB. So we need a forwarding pattern that determines |
33 | * the reference type to use by looking at the explicitly supplied operator signature, not by looking at |
34 | * the argument we're calling it with. |
35 | * |
36 | * What does std::forward do, exactly? |
37 | * ------------------------------------ |
38 | * std::forward<T>(t) is a way to cast t to the reference type supplied in T. |
39 | * Let's assume decay_t<T> == U and T is either U or some reference of U. |
40 | * - std::forward<T&>(t) will return U&, no matter what kind of reference t is. |
41 | * - std::forward<T&&>(t) will return U&&, no matter what kind of reference t is. |
42 | * - std::forward<T>(t) will return U&& (not U!), no matter what kind of reference t is. |
43 | * |
44 | * For universal references, that means that in the following function |
45 | * > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); } |
46 | * |
 * - when called with arg being an rvalue reference or a non-reference value, T gets inferred to be
 *   a non-reference U, and std::forward<T>(arg) will return U&&, correctly moving the argument.
 * - when called with arg being an lvalue reference, T gets inferred to be U& because that's the only
 *   way to match the signature (in C++, a type that is (U&)&& will collapse to U&).
 *   That means std::forward<T>(arg) will return U& and the value will not be moved but passed on as
 *   an lvalue reference.
53 | * |
54 | * How do we use that? |
55 | * ------------------------------------ |
56 | * But std::forward can also be used outside of the common "universal forwarding" pattern to change |
57 | * reference types. So instead of following the common C++ pattern, we notice what |
58 | * std::forward<T>() actually does, and that is it takes a value and changes its reference to the |
59 | * type of reference passed in as T. If we don't infer T but explicitly specify it, we can use this |
60 | * to forward based on an explicitly specified reference type instead of the inferred argument type. |
61 | * |
62 | * This is why many of the dispatcher functions look like |
63 | * > template<class T> func(T t) { func2<T>(std::forward<T>(t)); } |
64 | * instead of the common |
65 | * > template<class T> func(T&& t) { func2(std::forward<T>(t)); } |
66 | * |
67 | * and are expected to be called by explicitly specifying the template parameters in a way that matches |
68 | * the expected operator signature at each call site. |
69 | */ |
70 | |
71 | namespace impl { |
// supported_primitive_arg_types defines which primitive types we allow in
// kernel functions as arguments or returns.
// Additionally, we support lists, dicts and optionals containing these types.
// The assert_is_valid_input_type / assert_is_valid_output_type checks in this
// header query this list through guts::typelist::contains.
using supported_primitive_arg_types = guts::typelist::typelist<
  int64_t,
  double,
  bool,
  c10::string_view,
  at::Tensor,
  at::Scalar,
  c10::QScheme,
  c10::ScalarType,
  c10::Device,
  c10::Layout,
  c10::MemoryFormat,
  at::Dimname
>;
89 | |
90 | // We have an unboxed functor in hand that takes C++ arguments, and |
91 | // we're building a boxed functor wrapper for it that takes IValues. |
92 | // So "outside" is boxed and "inside" is unboxed. |
93 | // |
94 | // So a valid input type is one that our boxed functor wrapper can |
95 | // unbox from an IValue into a C++ value. |
96 | // |
// Whereas a valid output type is one that our wrapper can receive
98 | // as a C++ value from the unboxed functor, and box into an IValue. |
99 | |
100 | // |
101 | // assert_is_valid_input_type |
102 | // checks that T can be unboxed from an IValue into a C++ value. |
103 | // |
104 | |
// Primary template, used for every T that none of the specializations match.
// The check is performed by instantiating (default-constructing) this type.
// The constructor has no runtime effect: both if_constexpr branches are empty
// and only document why each case is acceptable.
// `Enable` is the hook used by the enable_if-based specializations.
template<class T, bool AllowDeprecatedTypes, class Enable = void>
struct assert_is_valid_input_type {
  assert_is_valid_input_type() {
    guts::if_constexpr<guts::typelist::contains<supported_primitive_arg_types, T>::value>([] {
      /* everything is ok, this is a primitive type */
    }, /* else */ [] {
      /* otherwise this must be an instance of a valid custom class, since it can only
      have been created via IValue(x), which ensures this. */
    });
  }
};
116 | |
// optional<T> is a valid input exactly when T is: the check is delegated to
// the contained type by inheriting from its checker.
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<c10::optional<T>, AllowDeprecatedTypes>
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {};
120 | |
// TypeCheckHelper applies assert_is_valid_input_type to every type in Args.
// It recurses by inheritance: each level checks `Head` through a data member
// and derives from the helper for the remaining types.
template <bool AllowDeprecatedTypes, class... Args>
struct TypeCheckHelper;

// Recursion end: no types left to check.
template <bool AllowDeprecatedTypes>
struct TypeCheckHelper<AllowDeprecatedTypes> {};

// Recursive step: instantiate the check for Head, then continue with Rest.
template <bool AllowDeprecatedTypes, class Head, class... Rest>
struct TypeCheckHelper<AllowDeprecatedTypes, Head, Rest...>
: TypeCheckHelper<AllowDeprecatedTypes, Rest...> {
  assert_is_valid_input_type<Head, AllowDeprecatedTypes> check;
};
132 | |
// A tuple input is valid when every element type is valid; TypeCheckHelper
// instantiates the check for each of the Contained types.
template<class... Contained, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<std::tuple<Contained...>, AllowDeprecatedTypes>
: TypeCheckHelper<AllowDeprecatedTypes, Contained...> {};
136 | |
// Dict<Key, Value> input: Value is checked recursively (via the base class),
// and Key must be one of impl::valid_dict_key_types.
template<class Key, class Value, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<Dict<Key, Value>, AllowDeprecatedTypes>
: assert_is_valid_input_type<Value, AllowDeprecatedTypes> {
  static_assert(guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
    "You tried to register a kernel with an unsupported input type: Dict<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string." );
};
143 | |
// std::unordered_map<Key, Value> input: only accepted when AllowDeprecatedTypes
// is set. Value is checked recursively and Key must be one of
// impl::valid_dict_key_types, as for Dict.
template<class Key, class Value, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<std::unordered_map<Key, Value>, AllowDeprecatedTypes>
: assert_is_valid_input_type<Value, AllowDeprecatedTypes> {
  static_assert(AllowDeprecatedTypes,
    "You tried to register a kernel with an unsupported input type: std::unordered_map<Key, Value>. Please use Dict<Key, Value> instead." );
  static_assert(guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
    "You tried to register a kernel with an unsupported input type: std::unordered_map<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string." );
};
152 | |
// List<T> input: the element type is checked recursively; List<Scalar> is
// rejected explicitly.
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<List<T>, AllowDeprecatedTypes>
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
  static_assert(!std::is_same<T, at::Scalar>::value,
    "You tried to register a kernel with an unsupported input type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead." );
};
159 | |
// ArrayRef<T> input: the element type is checked recursively; ArrayRef<Scalar>
// is rejected explicitly.
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<c10::ArrayRef<T>, AllowDeprecatedTypes>
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
  static_assert(!std::is_same<T, at::Scalar>::value,
    "You tried to register a kernel with an unsupported input type: ArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead." );
};
166 | |
// OptionalArrayRef<T> input: the element type is checked recursively;
// OptionalArrayRef<Scalar> is rejected explicitly.
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<c10::OptionalArrayRef<T>, AllowDeprecatedTypes>
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
  static_assert(!std::is_same<T, at::Scalar>::value,
    "You tried to register a kernel with an unsupported input type: OptionalArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead." );
};
173 | |
// std::array<T, N> input: the element type is checked recursively;
// std::array<Scalar, N> is rejected explicitly.
template<class T, size_t N, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<std::array<T, N>, AllowDeprecatedTypes>
: assert_is_valid_input_type<T, AllowDeprecatedTypes> {
  static_assert(!std::is_same<T, at::Scalar>::value,
    "You tried to register a kernel with an unsupported input type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead." );
};
180 | |
// The following specialisations of assert_is_valid_input_type are technically not
// necessary since we would hit the base case and show an error message
// there if they didn't exist, but we can show a better error message
// in some common error scenarios.
// NOTE(review): the primary template in this header has two empty branches and
// emits no diagnostic itself - confirm where the fallback error would come from.
// Each specialization is selected through the `Enable` parameter and asserts
// on guts::false_t<T>::value (presumably always false but dependent on T, so
// the assertion only fires when the specialization is instantiated).

// float: rejected in favour of double.
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<float, T>::value>> {
  // There is no reason to support float when we have double. Keep the API lean.
  static_assert(guts::false_t<T>::value,
    "You tried to register a kernel with an unsupported input type: float. Please use double instead." );
};
// const char*: rejected in favour of c10::string_view.
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<const char*, T>::value>> {
  static_assert(guts::false_t<T>::value,
    "You tried to register a kernel with an unsupported input type: const char*. Please use c10::string_view instead." );
};
// std::vector<bool>: rejected in favour of List<bool>.
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<std::vector<bool>, T>::value>> {
  static_assert(guts::false_t<T>::value,
    "You tried to register a kernel with an unsupported input type: vector<bool>. Please use List<bool> instead." );
};
// Any integral type that is not itself in supported_primitive_arg_types
// (int64_t and bool are in the list, so they do not match here).
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_integral<T>::value && !guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
  static_assert(guts::false_t<T>::value,
    "You tried to register a kernel with an unsupported integral input type. Please use int64_t instead." );
};
206 | |
207 | // |
208 | // assert_is_valid_output_type |
209 | // |
210 | |
// Primary template, used for every return type T that none of the
// specializations match. The check is performed by instantiating
// (default-constructing) this type. The constructor has no runtime effect:
// both if_constexpr branches are empty and only document why each case is fine.
// `Enable` is the hook used by the enable_if-based specializations.
template<class T, bool AllowDeprecatedTypes, class Enable = void>
struct assert_is_valid_output_type {
  assert_is_valid_output_type() {
    guts::if_constexpr<guts::typelist::contains<supported_primitive_arg_types, T>::value>([] {
      /* everything is ok, this is a primitive type */
    }, /* else */ [] {
      /* otherwise T is verified to be a registered custom class in the IValue
      constructor, so no benefit in double-checking here */
    });
  }
};
222 | |
// optional<T> is a valid output exactly when T is; the check is delegated to T.
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_output_type<c10::optional<T>, AllowDeprecatedTypes>
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {};
226 | |
// OptionalArrayRef<T> output: only the element type is checked. Unlike the
// input-side specialization, there is no explicit Scalar rejection here.
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_output_type<c10::OptionalArrayRef<T>, AllowDeprecatedTypes>
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {};
230 | |
// Dict<Key, Value> output: Value is checked recursively, Key must be one of
// impl::valid_dict_key_types, and Dict<Key, Scalar> is rejected explicitly.
template<class Key, class Value, bool AllowDeprecatedTypes>
struct assert_is_valid_output_type<Dict<Key, Value>, AllowDeprecatedTypes>
: assert_is_valid_output_type<Value, AllowDeprecatedTypes> {
  static_assert(guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
    "You tried to register a kernel with an unsupported output type: Dict<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string." );
  static_assert(!std::is_same<Value, at::Scalar>::value,
    "You tried to register a kernel with an unsupported output type: Dict<Key, Scalar>. Please use Dict<Key, int64_t> or Dict<Key, double>." );
};
239 | |
// std::unordered_map<Key, Value> output: only accepted when
// AllowDeprecatedTypes is set. Otherwise the same rules as Dict apply
// (valid key type, Value checked recursively, no Scalar values).
template<class Key, class Value, bool AllowDeprecatedTypes>
struct assert_is_valid_output_type<std::unordered_map<Key, Value>, AllowDeprecatedTypes>
: assert_is_valid_output_type<Value, AllowDeprecatedTypes> {
  static_assert(AllowDeprecatedTypes,
    "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Value>. Please use Dict<Key, Value> instead." );
  static_assert(guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
    "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string." );
  static_assert(!std::is_same<Value, at::Scalar>::value,
    "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Scalar>. Please use Dict<Key, int64_t> or Dict<Key, double>." );
};
250 | |
// List<T> output: the element type is checked recursively; List<Scalar> is
// rejected explicitly.
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_output_type<List<T>, AllowDeprecatedTypes>
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {
  static_assert(!std::is_same<T, at::Scalar>::value,
    "You tried to register a kernel with an unsupported output type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead." );
};
257 | |
// std::vector<T> output: the element type is checked recursively;
// std::vector<Scalar> is rejected explicitly. std::vector is currently accepted
// regardless of AllowDeprecatedTypes (the stricter assertion is still a TODO).
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_output_type<std::vector<T>, AllowDeprecatedTypes>
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {
  static_assert(!std::is_same<T, at::Scalar>::value,
    "You tried to register a kernel with an unsupported output type: std::vector<Scalar>. Please use List<int64_t>, List<double> or Tensor instead." );
  // TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel with an unsupported output type: std::vector<T>. Please use List<T> instead.");
};
265 | |
// std::array<T, N> output: the element type is checked recursively;
// std::array<Scalar, N> is rejected explicitly.
template<class T, size_t N, bool AllowDeprecatedTypes>
struct assert_is_valid_output_type<std::array<T, N>, AllowDeprecatedTypes>
: assert_is_valid_output_type<T, AllowDeprecatedTypes> {
  static_assert(!std::is_same<T, at::Scalar>::value,
    "You tried to register a kernel with an unsupported output type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead." );
};
272 | |
// The following specialisations of assert_is_valid_output_type are technically not
// necessary since we would hit the base case and show an error message
// there if they didn't exist, but we can show a better error message
// in some common error scenarios.
// NOTE(review): the primary template in this header has two empty branches and
// emits no diagnostic itself - confirm where the fallback error would come from.
// Each specialization is selected through the `Enable` parameter and asserts
// on guts::false_t<T>::value (presumably always false but dependent on T, so
// the assertion only fires when the specialization is instantiated).

// float: rejected in favour of double.
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_output_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<float, T>::value>> {
  // There is no reason to support float when we have double. Keep the API lean.
  static_assert(guts::false_t<T>::value,
    "You tried to register a kernel with an unsupported output type: float. Please use double instead." );
};
// const char*: rejected in favour of c10::string_view.
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_output_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<const char*, T>::value>> {
  static_assert(guts::false_t<T>::value,
    "You tried to register a kernel with an unsupported output type: const char*. Please use c10::string_view instead." );
};
// std::vector<bool>: rejected in favour of List<bool>.
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_output_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<std::vector<bool>, T>::value>> {
  static_assert(guts::false_t<T>::value,
    "You tried to register a kernel with an unsupported output type: vector<bool>. Please use List<bool> instead." );
};
// Any integral type that is not itself in supported_primitive_arg_types
// (int64_t and bool are in the list, so they do not match here).
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_output_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_integral<T>::value && !guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
  static_assert(guts::false_t<T>::value,
    "You tried to register a kernel with an unsupported integral output type. Please use int64_t instead." );
};
298 | |
299 | // ivalue_to_arg |
300 | |
// Maps a kernel parameter type to the type it is unboxed as. In the general
// case this strips references and cv-qualifiers and applies array/function
// decay, exactly as std::decay does.
template<class T>
struct decay_if_not_tensor final {
  using type = typename std::decay<T>::type;
};
305 | |
// Exception to the decay rule: `at::Tensor&` keeps its reference so that the
// dedicated ivalue_to_arg<at::Tensor&, ...> specialization is selected.
template<>
struct decay_if_not_tensor<at::Tensor&> final {
  using type = at::Tensor&;
};

// Exception to the decay rule: `const at::Tensor&` keeps its reference so that
// the dedicated ivalue_to_arg<const at::Tensor&, ...> specialization is selected.
template<>
struct decay_if_not_tensor<const at::Tensor&> final {
  using type = const at::Tensor&;
};
315 | |
316 | template<class T, bool AllowDeprecatedTypes> |
317 | struct ivalue_to_arg final { |
318 | static decltype(auto) call(IValue& v) { |
319 | assert_is_valid_input_type<T, AllowDeprecatedTypes>(); |
320 | return std::move(v).to<T>(); |
321 | } |
322 | }; |
323 | |
// The following two specializations take advantage of specialized
// `toTensor()` overloads on IValue to avoid copying.
template<bool AllowDeprecatedTypes>
struct ivalue_to_arg<at::Tensor&, AllowDeprecatedTypes> final {
  // We cannot use the default implementation if they asked for a
  // `at::Tensor&` because it moves from the IValue, so it can't get
  // an lvalue reference.
  // The returned reference is obtained from `v` itself, so it is only usable
  // while that IValue stays on the stack.
  static at::Tensor& call(IValue& v) {
    // Tensor& is valid, don't bother asserting
    return v.toTensor();
  }
};
336 | |
// Unboxing for `const at::Tensor&` parameters: hands out a reference obtained
// from the IValue instead of moving the tensor out of it.
template<bool AllowDeprecatedTypes>
struct ivalue_to_arg<const at::Tensor&, AllowDeprecatedTypes> final {
  // We should not use the default implementation if they asked for
  // a `const at::Tensor&` because it moves from the IValue and they
  // didn't ask for that.
  static const at::Tensor& call(IValue& v) {
    // const Tensor& is valid, don't bother asserting
    return v.toTensor();
  }
};
347 | |
// Unboxing for at::ITensorListRef parameters: returns an owning
// List<at::Tensor> obtained from the IValue (the IValue is not moved from).
// NOTE(review): presumably the List converts implicitly to ITensorListRef at
// the kernel call site - confirm against IListRef.h.
template<bool AllowDeprecatedTypes>
struct ivalue_to_arg<at::ITensorListRef, AllowDeprecatedTypes> final {
  static List<at::Tensor> call(IValue& v) {
    return v.toTensorList();
  }
};
354 | |
355 | template<class T, bool AllowDeprecatedTypes> |
356 | struct ivalue_to_arg<ArrayRef<T>, AllowDeprecatedTypes> final { |
357 | // If an argument is ArrayRef<T>, convert the IValue to a std::vector<T> and pass that |
358 | // to the operator. std::vector<T> is implicitly convertible to ArrayRef<T>. |
359 | static std::vector<T> call(IValue& v) { |
360 | return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call(v); |
361 | } |
362 | }; |
363 | template<bool AllowDeprecatedTypes> |
364 | struct ivalue_to_arg<c10::SymIntArrayRef, AllowDeprecatedTypes> final { |
365 | static std::vector<c10::SymInt> call(IValue& v) { |
366 | if (v.isIntList()) { |
367 | std::vector<c10::SymInt> r; |
368 | auto src = v.toIntList(); |
369 | std::transform(src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { return c10::SymInt(i); }); |
370 | return r; |
371 | } else { |
372 | return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::call(v); |
373 | } |
374 | } |
375 | }; |
376 | template<bool AllowDeprecatedTypes> |
377 | struct ivalue_to_arg<c10::OptionalArray<c10::SymInt>, AllowDeprecatedTypes> final { |
378 | static OptionalArray<c10::SymInt> call(IValue& v) { |
379 | if (v.isIntList()) { |
380 | std::vector<c10::SymInt> r; |
381 | auto src = v.toIntList(); |
382 | std::transform(src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { return c10::SymInt(i); }); |
383 | return OptionalArray<c10::SymInt>(std::move(r)); |
384 | } else { |
385 | return std::move(v).to<OptionalArray<c10::SymInt>>(); |
386 | } |
387 | } |
388 | }; |
389 | template<class T, bool AllowDeprecatedTypes> |
390 | struct ivalue_to_arg<optional<ArrayRef<T>>, AllowDeprecatedTypes> final { |
391 | // If an argument is optional<ArrayRef<T>>, convert the IValue to an optional<std::vector<T>> and pass that |
392 | // to the operator. OptionalArray<T> is basically a optional<std::vector<T>> but implicitly convertible |
393 | // to optional<ArrayRef<T>>. |
394 | static OptionalArray<T> call(IValue& v) { |
395 | return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v); |
396 | } |
397 | }; |
398 | |
399 | template<class T, bool AllowDeprecatedTypes> |
400 | struct ivalue_to_arg<OptionalArrayRef<T>, AllowDeprecatedTypes> final { |
401 | // If an argument is OptionalArrayRef<T>, convert the IValue to an |
402 | // optional<std::vector<T>> and pass that to the operator. OptionalArray<T> |
403 | // is basically a optional<std::vector<T>> but implicitly convertible to |
404 | // OptionalArrayRef<T> |
405 | static OptionalArray<T> call(IValue& v) { |
406 | return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v); |
407 | } |
408 | }; |
409 | |
// return_to_ivalue
// Boxes a kernel return value of type T into an IValue. The primary template
// is intentionally empty; the two specializations below (everything except
// `at::Tensor&`, and `at::Tensor&` itself) provide `call` and `copy`.
template<class T, bool AllowDeprecatedTypes, class Enable = void>
struct return_to_ivalue final {};
413 | |
414 | template<class T, bool AllowDeprecatedTypes> |
415 | struct return_to_ivalue<T, AllowDeprecatedTypes, std::enable_if_t<!std::is_same<at::Tensor&, T>::value>> final { |
416 | static IValue call(T&& v) { |
417 | assert_is_valid_output_type<T, AllowDeprecatedTypes>(); |
418 | return c10::ivalue::from(std::move(v)); |
419 | } |
420 | static IValue copy(const T& v) { |
421 | assert_is_valid_output_type<T, AllowDeprecatedTypes>(); |
422 | return IValue(v); |
423 | } |
424 | }; |
425 | |
// Special case to allow kernels to return `Tensor&`.
// TODO Delete this once kernels don't do that anymore
// Neither function moves from `v`, and no output type assertion is performed
// in this specialization.
template<bool AllowDeprecatedTypes>
struct return_to_ivalue<at::Tensor&, AllowDeprecatedTypes, void> final {
  static IValue call(at::Tensor& v) {
    return c10::ivalue::from(v);
  }
  static IValue copy(at::Tensor& v) {
    return IValue(v);
  }
};
437 | |
// wrap_kernel_functor_unboxed_
// Adapts a kernel functor to the dispatcher's unboxed calling convention
// (OperatorKernel*, DispatchKeySet, args...). The primary template is
// intentionally empty; the specializations below match on the functor's
// signature and provide `call`.
template<class KernelFunctor, class OpSignature>
struct wrap_kernel_functor_unboxed_ final {};
442 | |
443 | // This specialization is for kernels with a first argument that is NOT of type DispatchKeySet |
444 | // This includes kernels with 0 arguments. |
445 | template<class KernelFunctor, class ReturnType, class... ParameterTypes> |
446 | struct wrap_kernel_functor_unboxed_<KernelFunctor, ReturnType(ParameterTypes...)> final { |
447 | static_assert(std::is_same<ReturnType, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value, |
448 | "Return type mismatch" ); |
449 | static_assert(std::is_same<guts::typelist::typelist<ParameterTypes...>, typename guts::infer_function_traits_t<KernelFunctor>::parameter_types>::value, |
450 | "Parameter types mismatch" ); |
451 | |
452 | // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes doesn't use && |
453 | static ReturnType call(OperatorKernel* functor, DispatchKeySet, ParameterTypes... args) { |
454 | KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor); |
455 | // Note [Plumbing Keys Through The Dispatcher 2] |
456 | // See Note [Plumbing Keys Through The Dispatcher] for the background. |
457 | // This functor explicitly takes in a dispatchKeySet and drops it on the floor- it does not forward it to the registered kernel. |
458 | // |
459 | // This is due to the calling convention within the dispatcher, which expects all registered kernels to have a first argument of type |
460 | // DispatchKeySet. |
461 | // This is not the case for pretty much all manually written kernels, however- this functor serves to separate the calling convention |
462 | // of the dispatcher from the calling convention of manually written kernels. |
463 | return (*functor_)(std::forward<ParameterTypes>(args)...); |
464 | } |
465 | }; |
466 | |
467 | // This specialization is for kernels with a first argument of type DispatchKeySet |
468 | template<class KernelFunctor, class ReturnType, class... ParameterTypes> |
469 | struct wrap_kernel_functor_unboxed_<KernelFunctor, ReturnType(DispatchKeySet, ParameterTypes...)> final { |
470 | static_assert(std::is_same<ReturnType, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value, |
471 | "Return type mismatch" ); |
472 | static_assert(std::is_same<guts::typelist::typelist<DispatchKeySet, ParameterTypes...>, typename guts::infer_function_traits_t<KernelFunctor>::parameter_types>::value, |
473 | "Parameter types mismatch" ); |
474 | |
475 | // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes doesn't use && |
476 | static ReturnType call(OperatorKernel* functor, DispatchKeySet dispatchKeySet, ParameterTypes... args) { |
477 | KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor); |
478 | // We're explicitly taking in a dispatchKeySet and forwarding it to the registered kernel. |
479 | // See Note [Plumbing Keys Through The Dispatcher 2] for details. |
480 | return (*functor_)(dispatchKeySet, std::forward<ParameterTypes>(args)...); |
481 | } |
482 | }; |
483 | |
// Convenience alias: selects the wrap_kernel_functor_unboxed_ specialization
// matching the function type inferred from KernelFunctor.
template<class KernelFunctor>
using wrap_kernel_functor_unboxed = wrap_kernel_functor_unboxed_<KernelFunctor, typename guts::infer_function_traits_t<KernelFunctor>::func_type>;
486 | |
487 | // call_functor_with_args_from_stack |
488 | |
489 | template<class Functor, bool AllowDeprecatedTypes, size_t... ivalue_arg_indices, typename... ArgTypes> |
490 | std::decay_t<typename guts::infer_function_traits_t<Functor>::return_type> |
491 | call_functor_with_args_from_stack_(OperatorKernel* functor, DispatchKeySet dispatchKeySet, Stack* stack, std::index_sequence<ivalue_arg_indices...>, guts::typelist::typelist<ArgTypes...>*) { |
492 | (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would be unused and we have to silence the compiler warning. |
493 | |
494 | // We're explicitly filtering out DispatchKeySet from the argument list. |
495 | // Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher. |
496 | // We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack. |
497 | // See Note [Plumbing Keys Through The Dispatcher] for the background. |
498 | return wrap_kernel_functor_unboxed<Functor>::call(functor, dispatchKeySet, |
499 | ivalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type, AllowDeprecatedTypes>::call( |
500 | torch::jit::peek(*stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices)) |
501 | )...); |
502 | } |
503 | |
504 | template<class Functor, bool AllowDeprecatedTypes> |
505 | std::decay_t<typename guts::infer_function_traits_t<Functor>::return_type> |
506 | call_functor_with_args_from_stack(OperatorKernel* functor, DispatchKeySet dispatchKeySet, Stack* stack) { |
507 | // We're explicitly filtering out DispatchKeySet from the argument list. |
508 | // Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher. |
509 | // We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack. |
510 | // See Note [Plumbing Keys Through The Dispatcher] for the background. |
511 | using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func<Functor>::parameter_types; |
512 | constexpr size_t num_ivalue_args = guts::typelist::size<ArgTypes>::value; |
513 | return call_functor_with_args_from_stack_<Functor, AllowDeprecatedTypes>(functor, dispatchKeySet, stack, std::make_index_sequence<num_ivalue_args>(), static_cast<ArgTypes*>(nullptr)); |
514 | } |
515 | |
516 | // push_outputs |
517 | |
518 | template<class OutputType, bool AllowDeprecatedTypes> |
519 | struct push_outputs final { |
520 | // Contrary to [Note: Argument forwarding in the dispatcher], we use OutputType&& here |
521 | // to avoid one extra call to the move constructor in this case. This is still not a |
522 | // universal reference though because OutputType is an explicitly specified class |
523 | // template parameter. |
524 | static void call(OutputType&& output, Stack* stack) { |
525 | torch::jit::push(*stack, return_to_ivalue<OutputType, AllowDeprecatedTypes>::call(std::forward<OutputType>(output))); |
526 | } |
527 | static void copy(const OutputType& output, Stack* stack) { |
528 | torch::jit::push(*stack, return_to_ivalue<OutputType, AllowDeprecatedTypes>::copy(output)); |
529 | } |
530 | }; |
531 | template<class... OutputTypes, bool AllowDeprecatedTypes> |
532 | struct push_outputs<std::tuple<OutputTypes...>, AllowDeprecatedTypes> final { |
533 | static void call(std::tuple<OutputTypes...>&& output, Stack* stack) { |
534 | call_(std::move(output), stack, std::make_index_sequence<sizeof...(OutputTypes)>()); |
535 | } |
536 | static void copy(const std::tuple<OutputTypes...>& output, Stack* stack) { |
537 | copy_(output, stack, std::make_index_sequence<sizeof...(OutputTypes)>()); |
538 | } |
539 | |
540 | private: |
541 | template<size_t... indices> |
542 | static void call_(std::tuple<OutputTypes...>&& output, Stack* stack, std::index_sequence<indices...>) { |
543 | torch::jit::push(*stack, return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::call(std::forward<OutputTypes>(std::get<indices>(output)))...); |
544 | } |
545 | template<size_t... indices> |
546 | static void copy_(const std::tuple<OutputTypes...>& output, Stack* stack, std::index_sequence<indices...>) { |
547 | torch::jit::push(*stack, return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::copy(std::get<indices>(output))...); |
548 | } |
549 | }; |
// Kernels returning void produce nothing to push: both functions are no-ops.
// The `int` parameter is a placeholder for the missing output value.
template<bool AllowDeprecatedTypes>
struct push_outputs<void, AllowDeprecatedTypes> final {
  static void call(int /*dummy*/, Stack* /*stack*/) {
  }
  static void copy(int /*dummy*/, Stack* /*stack*/) {
  }
};
557 | |
// make_boxed_from_unboxed_functor
//
// Builds the boxed entry point for an unboxed kernel functor. `call` unboxes
// the arguments from the stack, invokes the functor, drops the inputs from the
// stack and, if the kernel returns something, pushes the boxed result(s).
// The body is written twice, interleaved through the preprocessor: once with
// native `if constexpr` (when __cpp_if_constexpr is defined) and once with the
// guts::if_constexpr emulation.

template<class KernelFunctor, bool AllowDeprecatedTypes>
struct make_boxed_from_unboxed_functor final {
  static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value,
    "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it." );

  // The OperatorHandle parameter is part of the signature but is not used here.
  static void call(OperatorKernel* functor, const OperatorHandle&, DispatchKeySet dispatchKeySet, Stack* stack) {
    using ReturnType = typename guts::infer_function_traits_t<KernelFunctor>::return_type;
    // We're explicitly filtering out DispatchKeySet from the argument list.
    // Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher.
    // We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack.
    // See Note [Plumbing Keys Through The Dispatcher] for the background.
    using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func<KernelFunctor>::parameter_types;
    constexpr bool has_outputs = !std::is_same<void, ReturnType>::value;
    constexpr size_t num_inputs = guts::typelist::size<ArgTypes>::value;
    // Branch taken when the kernel returns a value.
#ifdef __cpp_if_constexpr
    if constexpr (has_outputs) {
#else
    guts::if_constexpr<has_outputs>([&] (auto delay_check) {
#endif
      // Decay ReturnType to ReturnType_ so that if a reference gets returned, we actually store it by value
      // and don't get a dangling reference. This is only required because some kernels still return `Tensor&`.
#ifdef __cpp_if_constexpr
      // [Note: VC++ and 'std': ambiguous symbol]
      using ReturnType_ = ::std::decay_t<ReturnType>;
      ReturnType_ output = call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor, dispatchKeySet, stack);
#else
      using ReturnType_ = std::decay_t<typename decltype(delay_check)::template type_identity<ReturnType>>;
      ReturnType_ output = call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor, dispatchKeySet, delay_check(stack));
#endif
      // The inputs are removed only after the kernel has run, because the
      // arguments were read from the stack in place.
      torch::jit::drop(*stack, num_inputs);
      // See [Note: VC++ and 'std': ambiguous symbol]
      push_outputs<ReturnType_, AllowDeprecatedTypes>::call(::std::move(output), stack);
      // Branch taken when the kernel returns void: call it and drop the inputs.
#ifdef __cpp_if_constexpr
    } else {
#else
    }, /* else */ [&] {
#endif
      call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor, dispatchKeySet, stack);
      torch::jit::drop(*stack, num_inputs);
#ifdef __cpp_if_constexpr
    }
#else
    });
#endif
  }
};
606 | } // namespace impl |
607 | |
608 | } // namespace c10 |
609 | |
namespace torch {
  // Make c10::OperatorKernel available as torch::OperatorKernel.
  using OperatorKernel = c10::OperatorKernel;
}
613 | |