1 | #pragma once |
2 | |
3 | // This file contains boxing (not unboxing) logic, |
4 | // i.e. how to make a vector<IValue> from a set of concrete arguments. |
5 | |
6 | #include <ATen/core/ivalue.h> |
7 | #include <ATen/core/stack.h> |
8 | #include <c10/core/TensorOptions.h> |
9 | |
10 | #include <ATen/core/boxing/BoxedKernel.h> |
11 | |
12 | #include <c10/util/Metaprogramming.h> |
13 | |
14 | namespace c10 { |
15 | namespace impl { |
16 | |
17 | // |
18 | // utils |
19 | // |
20 | |
21 | // is_mutable_tensor_ref |
22 | template <class T> struct is_mutable_tensor_ref : std::false_type {}; |
23 | template <> struct is_mutable_tensor_ref<at::Tensor&> : std::true_type {}; |
24 | |
// is_tuple_of_mutable_tensor_refs
//
// True iff T is an instantiation of std::tuple whose element types are all
// `at::Tensor&`. Used to recognize the return type of out-of-place ops that
// return several out arguments (specialization 5 below).
template <class T, class Enable = void>
struct is_tuple_of_mutable_tensor_refs : std::false_type {};

// SFINAE-gated specialization for tuple types: folds is_mutable_tensor_ref
// over the tuple's element typelist.
template <class T>
struct is_tuple_of_mutable_tensor_refs<T, std::enable_if_t<guts::is_instantiation_of<std::tuple, T>::value, void>>
: guts::typelist::all<is_mutable_tensor_ref, guts::typelist::from_tuple_t<T>>
{};
34 | |
// has_ivalue_to<T> tests the presence/absence of instance method IValue::to<T>()
//
// Primary template: assume absent.
template <class T, class Enable = void>
struct has_ivalue_to : std::false_type {};

// Detection specialization: selected iff the expression
// `declval<IValue>().to<T>()` is well-formed (classic void_t idiom).
template <class T>
struct has_ivalue_to<T, guts::void_t<decltype(std::declval<IValue>().to<T>())>>
: std::true_type
{};
44 | |
45 | // |
46 | // boxing predicates |
47 | // |
48 | |
// can_box<T>: a boxable arg type is one that IValue has a constructor for
// (after stripping cv/ref qualifiers via decay).
template <typename T>
using can_box =
  guts::disjunction<
    std::is_constructible<IValue, std::decay_t<T>>,
    // TensorOptions are not directly constructible into IValue,
    // but torch::jit::push knows how to handle them
    std::is_same<TensorOptions, std::decay_t<T>>
  >;
58 | |
// can_box_all<Ts...>: true iff every type in the pack satisfies can_box.
template <typename... Ts>
using can_box_all = guts::conjunction<can_box<Ts>...>;
61 | |
// can_unbox<T>: an unboxable result is one that can be extracted from an
// IValue — i.e. IValue::to<T>() exists (or T is void). Lvalue references
// are rejected here since they cannot be materialized from a stack value;
// ops returning references are handled by dedicated specializations below.
template <typename T>
using can_unbox =
  guts::conjunction<
    guts::disjunction<
      has_ivalue_to<T>,
      // void returns are ok
      std::is_same<void, T>
    >,
    guts::negation<std::is_lvalue_reference<T>>
  >;
73 | |
74 | // |
75 | // boxArgs - utility for pushing unboxed args onto IValue stack |
76 | // |
77 | template <class... Args> |
78 | torch::jit::Stack boxArgs(Args... args) { |
79 | // TODO Reuse stack vector instead of allocating? |
80 | torch::jit::Stack stack; |
81 | stack.reserve(sizeof...(Args)); |
82 | torch::jit::push(stack, std::forward<Args>(args)...); |
83 | return stack; |
84 | } |
85 | |
// boxed_size_one<T>: number of stack slots torch::jit::push produces for a
// single argument of type T. Everything is one slot except TensorOptions
// (see the specialization below).
template <class T>
static inline constexpr size_t boxed_size_one() {
  // Guard: a TensorOptions passed by (const) reference would decay-match this
  // overload and report 1 slot instead of 4, silently corrupting the stack.
  static_assert(!std::is_same<std::decay_t<T>, c10::TensorOptions>::value, "need to patch this path to support TensorOptions passed by reference" );
  return 1;
}
91 | |
// torch::jit::push pushes 4 values for a TensorOptions; this needs to
// be kept in sync. (The 4 slots are dtype, layout, device, pinned_memory —
// see the TensorOptions overload of boxToStack below.)
template <>
inline constexpr size_t boxed_size_one<c10::TensorOptions>() {
  return 4;
}
98 | |
// NOTE: this could probably be simplified with C++17 fold expressions.
// BoxedSize<Args...>: compile-time sum of boxed_size_one<T>() over the pack,
// i.e. the total number of stack slots an argument list occupies when boxed.
template <typename...>
struct BoxedSize : std::integral_constant<size_t, 0> {};
template <class T, class... Args>
struct BoxedSize<T, Args...> : std::integral_constant<size_t, boxed_size_one<T>() + BoxedSize<Args...>::value> {};
104 | |
// boxed_size<Args...>(): function-style convenience wrapper over BoxedSize.
template <class... Args>
static inline constexpr size_t boxed_size() {
  return BoxedSize<Args...>::value;
}
109 | |
// Raw storage with the size and alignment of an IValue; callers placement-new
// IValues into arrays of this type (see boxToStack / boxArgsToStack below).
using IValueAlignedStorage = std::aligned_storage_t<sizeof(IValue), alignof(IValue)>;
111 | |
112 | template <typename T> |
113 | C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack(IValueAlignedStorage* dest, T& arg, int& lastIdx) { |
114 | new (&dest[lastIdx]) IValue(arg); |
115 | lastIdx++; |
116 | } |
117 | |
// TensorOptions overload: expands a TensorOptions into the 4 stack slots
// (dtype, layout, device, pinned_memory). Must stay in sync with
// boxed_size_one<c10::TensorOptions>() == 4 above.
C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack(IValueAlignedStorage* dest, c10::TensorOptions options, int& lastIdx) {
  new (&dest[lastIdx++]) IValue(c10::typeMetaToScalarType(options.dtype()));
  new (&dest[lastIdx++]) IValue(options.layout());
  new (&dest[lastIdx++]) IValue(options.device());
  new (&dest[lastIdx++]) IValue(options.pinned_memory());
}
124 | |
125 | inline void boxArgsToStack(IValueAlignedStorage*, int&) {} |
126 | |
127 | template<typename T, typename... Args> |
128 | C10_ALWAYS_INLINE_UNLESS_MOBILE void boxArgsToStack(IValueAlignedStorage* dest, int& lastIdx, T& arg, Args &... args) { |
129 | boxToStack(dest, arg, lastIdx); |
130 | boxArgsToStack(dest, lastIdx, args...); |
131 | } |
132 | |
133 | // |
134 | // PopResult is a helper class whose specializations handle popping single and |
135 | // multiple return values, respectively. |
136 | // |
137 | template <class Result> |
138 | struct PopResult final { |
139 | static Result call(Stack& stack) { |
140 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
141 | stack.size() == 1, |
142 | "Boxed kernel was expected to return one value on the stack, " , |
143 | "but instead pushed " , stack.size(), " values." |
144 | ); |
145 | return std::move(stack[0]).to<Result>(); |
146 | } |
147 | }; |
148 | |
// Tuple specialization: the boxed kernel has pushed one stack entry per tuple
// element; convert each entry and pack them into the result tuple.
template <class... Types>
struct PopResult<std::tuple<Types...>> final {
  using Result = std::tuple<Types...>;

  static Result call(Stack& stack) {
    // for tuple return types, boxed kernel has pushed multiple values onto the stack
    constexpr int RetCount = sizeof...(Types);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      stack.size() == RetCount,
      "Boxed kernel was expected to return " , RetCount, " values on the stack, " ,
      "but instead pushed " , stack.size(), " values."
    );
    return pop_to_tuple_impl(stack, std::make_index_sequence<RetCount>());
  }
private:
  // note: this has been moved into its own helper only to avoid a parse error on `indices` otherwise.
  // I'm sure there's an incantation that slips it past the parser but eh
  template <size_t... indices>
  static Result pop_to_tuple_impl(Stack& stack, std::index_sequence<indices...>) {
    // Stack order matches tuple order, so index i maps to element i.
    return std::make_tuple((std::move(stack[indices]).to<Types>())...);
  }
};
171 | |
172 | // |
173 | // BoxedKernelWrapper |
174 | // |
175 | // For a given function type FT, BoxedKernelWrapper<FT> implements |
176 | // a `call` method that |
177 | // - takes a boxed kernel and unboxed arguments as specified by FT, |
178 | // - calls `boxArgs` to box the arguments |
179 | // - calls the boxed kernel |
180 | // - unboxes and returns the result |
181 | // |
182 | // The partial specializations below handle various cases: in |
183 | // particular, not all types appearing in op signatures are supported, |
184 | // and ops returning references have nonstandard wrapper implementations. |
185 | // |
186 | |
// 1. The base specialization of BoxedKernelWrapper should never be instantiated.
// A "no call method defined on BoxedKernelWrapper" compile error means that
// an op signature has failed to trigger any of the partial specializations
// that follow this one.
//
template <class FuncType, class Enable = void>
struct BoxedKernelWrapper {
  // The reason we're not just doing straight up static_assert(false, ...) here:
  // Basically, the way to make sure a static_assert only fires if a template
  // is actually instantiated (rather than every time the file is parsed) is to use
  // template parameters in the expression, e.g. FuncType here. However, since
  // `sizeof(FuncType) != sizeof(FuncType)` is always false, this has the same
  // effect.
  static_assert(sizeof(FuncType) != sizeof(FuncType),
    "Function signature contains one or more unsupported parameter and/or return types. "
    "Look for a nearby error like "
    "\"'call' is not a member of 'c10::impl::BoxedKernelWrapper<(your function type), void>'\" "
    "- (your function type) is the unsupported signature." );
};
206 | |
207 | // |
208 | // 2. Supported signatures, other than those involving non-const Tensor refs - |
209 | // i.e., "functional" ops. |
210 | // |
211 | |
212 | template <class Result, class... Args> |
213 | struct BoxedKernelWrapper< |
214 | Result(Args...), |
215 | std::enable_if_t< |
216 | can_box_all<Args...>::value && can_unbox<Result>::value && !is_tuple_of_mutable_tensor_refs<Result>::value, |
217 | void |
218 | > |
219 | > { |
220 | static Result call( |
221 | const BoxedKernel& boxed_kernel_func, |
222 | const OperatorHandle& opHandle, |
223 | DispatchKeySet dispatchKeySet, |
224 | Args... args |
225 | ) { |
226 | torch::jit::Stack stack = boxArgs<Args...>(std::forward<Args>(args)...); |
227 | boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); |
228 | |
229 | return guts::if_constexpr<!std::is_same<void, Result>::value>( |
230 | [&] (auto delay_check) { |
231 | // op has pushed one or more values onto the stack. |
232 | return delay_check(PopResult<Result>::call(stack)); |
233 | }, |
234 | [&] { |
235 | // op returns void, boxed kernel has pushed nothing onto stack. |
236 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
237 | stack.empty(), |
238 | "Boxed kernel was expected to return no values on the stack, " , |
239 | "but instead returned " , stack.size(), " values." |
240 | ); |
241 | } |
242 | ); |
243 | } |
244 | }; |
245 | |
246 | // |
247 | // 3. in-place ops take a single non-const Tensor reference |
248 | // as their first argument, and return it. |
249 | // |
250 | // Note: all signatures matching this pattern are assumed to be for such ops. |
251 | // Because of this, the generated BoxedKernelWrapper specializations simply |
252 | // return the in-place argument. |
253 | // |
254 | |
255 | template <class... OtherArgs> |
256 | struct BoxedKernelWrapper< |
257 | at::Tensor&(at::Tensor&, OtherArgs...), |
258 | std::enable_if_t<can_box_all<OtherArgs...>::value, void> |
259 | > { |
260 | static at::Tensor& call( |
261 | const BoxedKernel& boxed_kernel_func, |
262 | const OperatorHandle& opHandle, |
263 | DispatchKeySet dispatchKeySet, |
264 | at::Tensor& outArg, OtherArgs... otherArgs |
265 | ) { |
266 | torch::jit::Stack stack = boxArgs<at::Tensor&, OtherArgs...>(outArg, std::forward<OtherArgs>(otherArgs)...); |
267 | boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); |
268 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
269 | stack.size() == 1, |
270 | "Boxed kernel was expected to return a single value on the stack, " , |
271 | "but instead returned " , stack.size(), " values." |
272 | ); |
273 | |
274 | return outArg; |
275 | } |
276 | }; |
277 | |
278 | // |
279 | // 3.5. In-process migration to make in-place ops take and return |
280 | // const references instead. |
281 | template <class... OtherArgs> |
282 | struct BoxedKernelWrapper< |
283 | const at::Tensor&(const at::Tensor&, OtherArgs...), |
284 | std::enable_if_t<can_box_all<OtherArgs...>::value, void> |
285 | > { |
286 | static const at::Tensor& call( |
287 | const BoxedKernel& boxed_kernel_func, |
288 | const OperatorHandle& opHandle, |
289 | DispatchKeySet dispatchKeySet, |
290 | const at::Tensor& outArg, OtherArgs... otherArgs |
291 | ) { |
292 | torch::jit::Stack stack = boxArgs(outArg, otherArgs...); |
293 | boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); |
294 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
295 | stack.size() == 1, |
296 | "Boxed kernel was expected to return a single value on the stack, " , |
297 | "but instead returned " , stack.size(), " values." |
298 | ); |
299 | |
300 | return outArg; |
301 | } |
302 | }; |
303 | |
304 | // |
305 | // 4. out of place ops that take a single non-const Tensor reference as their |
306 | // final argument, and also return it. |
307 | // |
308 | // Note: all signatures matching this pattern are assumed to be for such ops. |
309 | // This assumption permits the generated BoxedKernelWrapper specializations to simply |
310 | // return out arguments. |
311 | // |
312 | template <class FirstArg, class... RestArgs> |
313 | struct BoxedKernelWrapper< |
314 | at::Tensor&(FirstArg, RestArgs...), |
315 | std::enable_if_t< |
316 | can_box_all<FirstArg, RestArgs...>::value |
317 | // this skips over in-place kernels with a non-const Tensor |
318 | // arg at the front, so those can unambiguously trigger the preceding specialization. |
319 | && !is_mutable_tensor_ref<FirstArg>::value, |
320 | void |
321 | > |
322 | > { |
323 | static at::Tensor& call( |
324 | const BoxedKernel& boxed_kernel_func, |
325 | const OperatorHandle& opHandle, |
326 | DispatchKeySet dispatchKeySet, |
327 | FirstArg firstArg, RestArgs... restArgs |
328 | ) { |
329 | torch::jit::Stack stack = boxArgs<FirstArg, RestArgs...>(std::forward<FirstArg>(firstArg), std::forward<RestArgs>(restArgs)...); |
330 | boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); |
331 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
332 | stack.size() == 1, |
333 | "Boxed kernel was expected to return a single value on the stack, " , |
334 | "but instead returned " , stack.size(), " values." |
335 | ); |
336 | |
337 | // reusing restArgs after it has been forwarded here is ok because we know |
338 | // that the last element is of type `Tensor&`. |
339 | return std::get<sizeof...(RestArgs) - 1>(std::tuple<RestArgs...>{restArgs...}); |
340 | } |
341 | }; |
342 | |
343 | // |
344 | // 5. out of place ops that take multiple non-const Tensor references as their |
345 | // final arguments, and return them in a std::tuple. |
346 | // |
347 | // Note: all signatures matching this pattern are assumed to be for such ops. |
348 | // This assumption permits the generated BoxedKernelWrapper specializations to simply |
349 | // return the out arguments. |
350 | // |
351 | template <class Result, class... Args> |
352 | struct BoxedKernelWrapper< |
353 | Result(Args...), |
354 | std::enable_if_t< |
355 | can_box_all<Args...>::value && is_tuple_of_mutable_tensor_refs<Result>::value, |
356 | void |
357 | > |
358 | > { |
359 | static Result call( |
360 | const BoxedKernel& boxed_kernel_func, |
361 | const OperatorHandle& opHandle, |
362 | DispatchKeySet dispatchKeySet, |
363 | Args... args |
364 | ) { |
365 | using ArgTuple = std::tuple<Args...>; |
366 | constexpr int RetCount = std::tuple_size<Result>(); |
367 | |
368 | torch::jit::Stack stack = boxArgs<Args...>(std::forward<Args>(args)...); |
369 | boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack); |
370 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
371 | stack.size() == RetCount, |
372 | "Boxed kernel was expected to return " , RetCount, " values on the stack, " , |
373 | "but instead returned " , stack.size(), " values." |
374 | ); |
375 | |
376 | // reusing args after it has been forwarded here is ok because we know |
377 | // that the last RetCount elements are of type `Tensor&`. |
378 | auto result = guts::tuple_take<ArgTuple, -RetCount>(ArgTuple{std::forward<Args>(args)...}); |
379 | static_assert( |
380 | std::is_same<Result, decltype(result)>::value, |
381 | "The parameter list of an op returning a tuple of Tensor references " |
382 | "must end with an equal number of Tensor reference parameters." |
383 | ); |
384 | return result; |
385 | } |
386 | }; |
387 | |
388 | } // impl |
389 | } // c10 |
390 | |