1#pragma once
2
3// This file contains boxing (not unboxing) logic,
4// i.e. how to make a vector<IValue> from a set of concrete arguments.
5
6#include <ATen/core/ivalue.h>
7#include <ATen/core/stack.h>
8#include <c10/core/TensorOptions.h>
9
10#include <ATen/core/boxing/BoxedKernel.h>
11
12#include <c10/util/Metaprogramming.h>
13
14namespace c10 {
15namespace impl {
16
17//
18// utils
19//
20
21// is_mutable_tensor_ref
22template <class T> struct is_mutable_tensor_ref : std::false_type {};
23template <> struct is_mutable_tensor_ref<at::Tensor&> : std::true_type {};
24
25// is_tuple_of_mutable_tensor_refs
26//
// is_tuple_of_mutable_tensor_refs<T>: true iff T is a std::tuple whose
// element types are all `at::Tensor&`. Used below to recognize the return
// type of multi-output "out" ops. Primary template: anything that is not
// a std::tuple instantiation is false.
template <class T, class Enable = void>
struct is_tuple_of_mutable_tensor_refs : std::false_type {};

// Partial specialization, enabled only for std::tuple instantiations:
// true iff every element satisfies is_mutable_tensor_ref.
template <class T>
struct is_tuple_of_mutable_tensor_refs<T, std::enable_if_t<guts::is_instantiation_of<std::tuple, T>::value, void>>
: guts::typelist::all<is_mutable_tensor_ref, guts::typelist::from_tuple_t<T>>
{};
34
35// has_ivalue_to<T> tests the presence/absence of instance method IValue::to<T>()
36//
// has_ivalue_to<T> tests the presence/absence of instance method IValue::to<T>().
// Primary template: false when the expression below is ill-formed.
template <class T, class Enable = void>
struct has_ivalue_to : std::false_type {};

// void_t detection idiom: this specialization is selected exactly when
// `std::declval<IValue>().to<T>()` is well-formed, i.e. when a T can be
// extracted from an IValue.
template <class T>
struct has_ivalue_to<T, guts::void_t<decltype(std::declval<IValue>().to<T>())>>
: std::true_type
{};
44
45//
46// boxing predicates
47//
48
49// A boxable arg type is one that IValue has a constructor for.
// A boxable arg type is one that IValue has a constructor for
// (checked against the decayed type, so references and cv-qualifiers
// on the parameter don't matter).
template <typename T>
using can_box =
  guts::disjunction<
    std::is_constructible<IValue, std::decay_t<T>>,
    // TensorOptions are not directly constructible into IValue,
    // but torch::jit::push knows how to handle them
    std::is_same<TensorOptions, std::decay_t<T>>
  >;
58
// can_box_all<Ts...>: true iff every type in Ts... satisfies can_box.
template <typename... Ts>
using can_box_all = guts::conjunction<can_box<Ts>...>;
61
62// an unboxable result is one that can be extracted from an IValue
// an unboxable result is one that can be extracted from an IValue
// (via IValue::to<T>()), or void. Lvalue-reference results are excluded
// here; ops returning Tensor references are handled by the dedicated
// BoxedKernelWrapper specializations further down, which return the
// argument itself rather than unboxing from the stack.
template <typename T>
using can_unbox =
  guts::conjunction<
    guts::disjunction<
      has_ivalue_to<T>,
      // void returns are ok
      std::is_same<void, T>
    >,
    guts::negation<std::is_lvalue_reference<T>>
  >;
73
74//
75// boxArgs - utility for pushing unboxed args onto IValue stack
76//
// boxArgs: boxes the given unboxed arguments into a fresh IValue stack.
// Note: Args is always specified explicitly by callers (or deduced
// by-value), and std::forward forwards each arg according to its
// declared category in Args.
template <class... Args>
torch::jit::Stack boxArgs(Args... args) {
  // TODO Reuse stack vector instead of allocating?
  torch::jit::Stack stack;
  // NB: one slot per arg is only a hint -- a TensorOptions arg expands
  // to 4 IValues when pushed (see boxed_size_one below).
  stack.reserve(sizeof...(Args));
  torch::jit::push(stack, std::forward<Args>(args)...);
  return stack;
}
85
// boxed_size_one<T>(): number of IValue stack slots occupied by one
// argument of type T. Default is 1; TensorOptions is special-cased below.
template <class T>
static inline constexpr size_t boxed_size_one() {
  static_assert(!std::is_same<std::decay_t<T>, c10::TensorOptions>::value, "need to patch this path to support TensorOptions passed by reference");
  return 1;
}
91
// torch::jit::push pushes 4 values for a TensorOptions; this needs to
// be kept in sync. (The 4 values are dtype, layout, device and
// pinned_memory -- see the TensorOptions overload of boxToStack below.)
template <>
inline constexpr size_t boxed_size_one<c10::TensorOptions>() {
  return 4;
}
98
// NOTE: this could probably be simplified with C++17 fold expressions.
// BoxedSize<Args...>::value sums boxed_size_one<T>() over all argument
// types via recursive template instantiation -- i.e., the total number
// of IValue slots an argument list needs once boxed.
template <typename...>
struct BoxedSize : std::integral_constant<size_t, 0> {};
template <class T, class... Args>
struct BoxedSize<T, Args...> : std::integral_constant<size_t, boxed_size_one<T>() + BoxedSize<Args...>::value> {};
104
// boxed_size<Args...>(): function-style convenience wrapper over BoxedSize.
template <class... Args>
static inline constexpr size_t boxed_size() {
  return BoxedSize<Args...>::value;
}
109
// Uninitialized storage big and aligned enough for one IValue; callers
// construct IValues into arrays of this via placement new (boxToStack).
using IValueAlignedStorage = std::aligned_storage_t<sizeof(IValue), alignof(IValue)>;
111
// Constructs one IValue from `arg` into dest[lastIdx] via placement new,
// then advances lastIdx by the one slot consumed.
template <typename T>
C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack(IValueAlignedStorage* dest, T& arg, int& lastIdx) {
  new (&dest[lastIdx]) IValue(arg);
  lastIdx++;
}
117
// TensorOptions overload: expands the options into 4 IValues (dtype as
// ScalarType, layout, device, pinned_memory). Must stay in sync with
// boxed_size_one<c10::TensorOptions>() above and with torch::jit::push.
C10_ALWAYS_INLINE_UNLESS_MOBILE void boxToStack(IValueAlignedStorage* dest, c10::TensorOptions options, int& lastIdx) {
  new (&dest[lastIdx++]) IValue(c10::typeMetaToScalarType(options.dtype()));
  new (&dest[lastIdx++]) IValue(options.layout());
  new (&dest[lastIdx++]) IValue(options.device());
  new (&dest[lastIdx++]) IValue(options.pinned_memory());
}
124
// boxArgsToStack: boxes each argument in turn into the dest buffer via
// boxToStack, threading the running slot index through lastIdx.
// Base case: no arguments left, nothing to do.
inline void boxArgsToStack(IValueAlignedStorage*, int&) {}

// Recursive case: box the head argument, then recurse on the tail.
template<typename T, typename... Args>
C10_ALWAYS_INLINE_UNLESS_MOBILE void boxArgsToStack(IValueAlignedStorage* dest, int& lastIdx, T& arg, Args &... args) {
  boxToStack(dest, arg, lastIdx);
  boxArgsToStack(dest, lastIdx, args...);
}
132
133//
134// PopResult is a helper class whose specializations handle popping single and
135// multiple return values, respectively.
136//
// Primary template: single return value. The boxed kernel must have left
// exactly one IValue on the stack; convert it to Result via IValue::to.
template <class Result>
struct PopResult final {
  static Result call(Stack& stack) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      stack.size() == 1,
      "Boxed kernel was expected to return one value on the stack, ",
      "but instead pushed ", stack.size(), " values."
    );
    // ok to consume the stack entry with std::move -- the caller
    // discards the stack after this returns.
    return std::move(stack[0]).to<Result>();
  }
};
148
// Specialization for std::tuple results: the boxed kernel pushes one
// value per tuple element; pop them all (in stack order) into the tuple.
template <class... Types>
struct PopResult<std::tuple<Types...>> final {
  using Result = std::tuple<Types...>;

  static Result call(Stack& stack) {
    // for tuple return types, boxed kernel has pushed multiple values onto the stack
    constexpr int RetCount = sizeof...(Types);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      stack.size() == RetCount,
      "Boxed kernel was expected to return ", RetCount, " values on the stack, ",
      "but instead pushed ", stack.size(), " values."
    );
    return pop_to_tuple_impl(stack, std::make_index_sequence<RetCount>());
  }
private:
  // note: this has been moved into its own helper only to avoid a parse error on `indices` otherwise.
  // I'm sure there's an incantation that slips it past the parser but eh
  template <size_t... indices>
  static Result pop_to_tuple_impl(Stack& stack, std::index_sequence<indices...>) {
    // expands to one stack[i].to<T>() per (index, type) pair
    return std::make_tuple((std::move(stack[indices]).to<Types>())...);
  }
};
171
172//
173// BoxedKernelWrapper
174//
175// For a given function type FT, BoxedKernelWrapper<FT> implements
176// a `call` method that
177// - takes a boxed kernel and unboxed arguments as specified by FT,
178// - calls `boxArgs` to box the arguments
179// - calls the boxed kernel
180// - unboxes and returns the result
181//
182// The partial specializations below handle various cases: in
183// particular, not all types appearing in op signatures are supported,
184// and ops returning references have nonstandard wrapper implementations.
185//
186
187// 1. The base specialization of BoxedKernelWrapper should never be instantiated.
188// A "no call method defined on BoxedKernelWrapper" compile error means that
189// an op signature has failed to trigger any of the partial specializations
190// that follow this one.
191//
// Base (unspecialized) template: deliberately has no `call` member, so
// unsupported signatures fail to compile with the static_assert below.
template <class FuncType, class Enable = void>
struct BoxedKernelWrapper {
  // The reason we're not just doing straight up static_assert(false, ...) here:
  // Basically, the way to make sure a static_assert only fires if a template
  // is actually instantiated (rather than every time the file is parsed) is to use
  // template parameters in the expression, e.g. FuncType here. However, since
  // `sizeof(FuncType) != sizeof(FuncType)` is always false, this has the same
  // effect.
  static_assert(sizeof(FuncType) != sizeof(FuncType),
    "Function signature contains one or more unsupported parameter and/or return types. "
    "Look for a nearby error like "
    "\"'call' is not a member of 'c10::impl::BoxedKernelWrapper<(your function type), void>'\" "
    "- (your function type) is the unsupported signature.");
};
206
207//
208// 2. Supported signatures, other than those involving non-const Tensor refs -
209// i.e., "functional" ops.
210//
211
// "Functional" ops: all args boxable, result unboxable, and the result
// is not a tuple of mutable Tensor refs (those go to specialization 5).
template <class Result, class... Args>
struct BoxedKernelWrapper<
  Result(Args...),
  std::enable_if_t<
    can_box_all<Args...>::value && can_unbox<Result>::value && !is_tuple_of_mutable_tensor_refs<Result>::value,
    void
  >
> {
  static Result call(
    const BoxedKernel& boxed_kernel_func,
    const OperatorHandle& opHandle,
    DispatchKeySet dispatchKeySet,
    Args... args
  ) {
    // box the unboxed args onto a stack, invoke the boxed kernel (which
    // consumes the args and pushes its results), then unbox the result.
    torch::jit::Stack stack = boxArgs<Args...>(std::forward<Args>(args)...);
    boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);

    // guts::if_constexpr stands in for C++17 `if constexpr` here: only
    // the branch matching the compile-time condition is instantiated.
    return guts::if_constexpr<!std::is_same<void, Result>::value>(
      [&] (auto delay_check) {
        // op has pushed one or more values onto the stack.
        return delay_check(PopResult<Result>::call(stack));
      },
      [&] {
        // op returns void, boxed kernel has pushed nothing onto stack.
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          stack.empty(),
          "Boxed kernel was expected to return no values on the stack, ",
          "but instead returned ", stack.size(), " values."
        );
      }
    );
  }
};
245
246//
247// 3. in-place ops take a single non-const Tensor reference
248// as their first argument, and return it.
249//
250// Note: all signatures matching this pattern are assumed to be for such ops.
251// Because of this, the generated BoxedKernelWrapper specializations simply
252// return the in-place argument.
253//
254
// In-place ops: first arg is a mutable Tensor&, which is also the return
// value. The stack result is checked but discarded; the in-place argument
// itself is returned (see the section comment above).
template <class... OtherArgs>
struct BoxedKernelWrapper<
  at::Tensor&(at::Tensor&, OtherArgs...),
  std::enable_if_t<can_box_all<OtherArgs...>::value, void>
> {
  static at::Tensor& call(
    const BoxedKernel& boxed_kernel_func,
    const OperatorHandle& opHandle,
    DispatchKeySet dispatchKeySet,
    at::Tensor& outArg, OtherArgs... otherArgs
  ) {
    // explicit template args ensure outArg is boxed as a Tensor& and the
    // rest are forwarded with their declared categories.
    torch::jit::Stack stack = boxArgs<at::Tensor&, OtherArgs...>(outArg, std::forward<OtherArgs>(otherArgs)...);
    boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      stack.size() == 1,
      "Boxed kernel was expected to return a single value on the stack, ",
      "but instead returned ", stack.size(), " values."
    );

    // return the in-place argument, not the stack value
    return outArg;
  }
};
277
278//
279// 3.5. In-process migration to make in-place ops take and return
280// const references instead.
281template <class... OtherArgs>
282struct BoxedKernelWrapper<
283 const at::Tensor&(const at::Tensor&, OtherArgs...),
284 std::enable_if_t<can_box_all<OtherArgs...>::value, void>
285> {
286 static const at::Tensor& call(
287 const BoxedKernel& boxed_kernel_func,
288 const OperatorHandle& opHandle,
289 DispatchKeySet dispatchKeySet,
290 const at::Tensor& outArg, OtherArgs... otherArgs
291 ) {
292 torch::jit::Stack stack = boxArgs(outArg, otherArgs...);
293 boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
294 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
295 stack.size() == 1,
296 "Boxed kernel was expected to return a single value on the stack, ",
297 "but instead returned ", stack.size(), " values."
298 );
299
300 return outArg;
301 }
302};
303
304//
305// 4. out of place ops that take a single non-const Tensor reference as their
306// final argument, and also return it.
307//
308// Note: all signatures matching this pattern are assumed to be for such ops.
309// This assumption permits the generated BoxedKernelWrapper specializations to simply
310// return out arguments.
311//
// Out-of-place ops: last parameter is a mutable Tensor& "out" argument,
// which is also the return value (see the section comment above).
template <class FirstArg, class... RestArgs>
struct BoxedKernelWrapper<
  at::Tensor&(FirstArg, RestArgs...),
  std::enable_if_t<
    can_box_all<FirstArg, RestArgs...>::value
    // this skips over in-place kernels with a non-const Tensor
    // arg at the front, so those can unambiguously trigger the preceding specialization.
    && !is_mutable_tensor_ref<FirstArg>::value,
    void
  >
> {
  static at::Tensor& call(
    const BoxedKernel& boxed_kernel_func,
    const OperatorHandle& opHandle,
    DispatchKeySet dispatchKeySet,
    FirstArg firstArg, RestArgs... restArgs
  ) {
    torch::jit::Stack stack = boxArgs<FirstArg, RestArgs...>(std::forward<FirstArg>(firstArg), std::forward<RestArgs>(restArgs)...);
    boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      stack.size() == 1,
      "Boxed kernel was expected to return a single value on the stack, ",
      "but instead returned ", stack.size(), " values."
    );

    // reusing restArgs after it has been forwarded here is ok because we know
    // that the last element is of type `Tensor&`.
    // (The tuple holds references for reference-typed RestArgs, so this
    // extracts the out argument without copying it.)
    return std::get<sizeof...(RestArgs) - 1>(std::tuple<RestArgs...>{restArgs...});
  }
};
342
343//
344// 5. out of place ops that take multiple non-const Tensor references as their
345// final arguments, and return them in a std::tuple.
346//
347// Note: all signatures matching this pattern are assumed to be for such ops.
348// This assumption permits the generated BoxedKernelWrapper specializations to simply
349// return the out arguments.
350//
// Multi-output out-of-place ops: result is a std::tuple of mutable Tensor
// refs, which must mirror the trailing Tensor& parameters (see the
// section comment above).
template <class Result, class... Args>
struct BoxedKernelWrapper<
  Result(Args...),
  std::enable_if_t<
    can_box_all<Args...>::value && is_tuple_of_mutable_tensor_refs<Result>::value,
    void
  >
> {
  static Result call(
    const BoxedKernel& boxed_kernel_func,
    const OperatorHandle& opHandle,
    DispatchKeySet dispatchKeySet,
    Args... args
  ) {
    using ArgTuple = std::tuple<Args...>;
    // number of tuple elements == number of values the boxed kernel pushes
    constexpr int RetCount = std::tuple_size<Result>();

    torch::jit::Stack stack = boxArgs<Args...>(std::forward<Args>(args)...);
    boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      stack.size() == RetCount,
      "Boxed kernel was expected to return ", RetCount, " values on the stack, ",
      "but instead returned ", stack.size(), " values."
    );

    // reusing args after it has been forwarded here is ok because we know
    // that the last RetCount elements are of type `Tensor&`.
    // tuple_take<..., -RetCount> keeps the LAST RetCount elements, i.e.
    // the out arguments.
    auto result = guts::tuple_take<ArgTuple, -RetCount>(ArgTuple{std::forward<Args>(args)...});
    static_assert(
      std::is_same<Result, decltype(result)>::value,
      "The parameter list of an op returning a tuple of Tensor references "
      "must end with an equal number of Tensor reference parameters."
    );
    return result;
  }
};
387
388} // impl
389} // c10
390