1#pragma once
2
3#include <type_traits>
4
5#include <ATen/core/ivalue.h>
6#include <c10/util/Deprecated.h>
7#include <c10/util/irange.h>
8
9// TODO move this to c10 namespace
10
11namespace torch {
12namespace jit {
13
// Bring c10::IValue into torch::jit so the helpers below can name it
// unqualified.
using c10::IValue;
// The interpreter's operand stack: a plain vector of IValues whose back() is
// the top of the stack (pop() below takes stack.back()).
using Stack = std::vector<IValue>;
16
17class Operation {
18 template <typename F, typename Arg>
19 using accepts = std::is_constructible<std::function<void(Arg)>, F&&>;
20
21 public:
22 template <typename F,
23 std::enable_if_t<accepts<F, Stack*>::value, int> = 0>
24 C10_DEPRECATED_MESSAGE("Please use void(Stack&) to register operator instead.")
25 Operation(F&& raw): op_([raw = std::forward<F>(raw)](Stack& stack) {
26 raw(&stack);
27 }) {}
28
29 template <typename F,
30 std::enable_if_t<accepts<F, Stack&>::value &&
31 !std::is_same<std::decay_t<F>, Operation>::value, int> = 0>
32 Operation(F&& op): op_(std::forward<F>(op)) {}
33
34 Operation(std::nullptr_t) noexcept {}
35
36 explicit operator bool() const noexcept {
37 return op_ ? true : false;
38 }
39
40 void operator()(Stack& stack) {
41 op_(stack);
42 }
43
44 template <typename T>
45 T* target() noexcept {
46 return op_.target<T>();
47 }
48
49 private:
50 std::function<void(Stack&)> op_;
51};
52
// An operation with N inputs and M outputs pops the last N inputs off
// the stack and pushes its M outputs onto the stack:
// before: <other stack items> I0, I1, ... I(N-1) <- stack.back()
// after:  <other stack items> O0, O1, ... O(M-1) <- stack.back()
// Operations are defined this way so that ownership of inputs can be
// transferred to the operation and it can incrementally drop ownership of
// tensors when they become unneeded. For large operations, like 'run an entire
// subgraph', this functionality is very important for minimizing GPU memory
// usage. The return value is the relative 'offset' to jump to for the next
// operation:
//   pc += 1 + offset
// so a return value of 0 goes to the next instruction.
// NOTE(review): `Operation` above wraps std::function<void(Stack&)> and thus
// returns nothing; the offset convention may be stale -- confirm before
// relying on it.
65
66// treat the last N elements of the stack as a list, looking up
67// element i
68static inline IValue& peek(Stack& stack, size_t i, size_t N) {
69 return *(stack.end() - N + i);
70}
71static inline IValue& peek(Stack* stack, size_t i, size_t N) {
72 return peek(*stack, i, N);
73}
74static inline const IValue& peek(const Stack& stack, size_t i, size_t N) {
75 return *(stack.end() - N + i);
76}
77static inline const IValue& peek(const Stack* stack, size_t i, size_t N) {
78 return peek(*stack, i, N);
79}
80// treat the last N elements of the stack as a list, looking up the
81// slice starting at index i and having length len
82static inline at::ArrayRef<IValue> peekSlice(
83 const Stack& stack,
84 size_t i,
85 size_t len,
86 size_t N) {
87 return at::ArrayRef<IValue>(stack).slice(stack.size() - N + i, len);
88}
89static inline at::ArrayRef<IValue> last(const Stack& stack, size_t N) {
90 return peekSlice(stack, 0, N, N);
91}
92static inline at::ArrayRef<IValue> last(const Stack* stack, size_t N) {
93 return last(*stack, N);
94}
95static inline void drop(Stack& stack, size_t n) {
96 stack.erase(stack.end() - n, stack.end());
97}
98static inline void drop(Stack* stack, size_t n) {
99 drop(*stack, n);
100}
101static inline IValue pop(Stack& stack) {
102 auto r = std::move(stack.back());
103 stack.pop_back();
104 return r;
105}
106static inline IValue pop(Stack* stack) {
107 return pop(*stack);
108}
109static inline std::vector<IValue> pop(Stack& stack, size_t n) {
110 std::vector<IValue> result;
111 result.reserve(n);
112 for (const auto i : c10::irange(n)) {
113 result.push_back(std::move(peek(stack, i, n)));
114 }
115 drop(stack, n);
116 return result;
117}
118
119// variadic pop:
120// int64_t a; at::Tensor b;
121// pop(stack, a, b);
122// equivalent to:
123// b = pop(stack).toTensor();
124// a = pop(stack).toInt();
125template <typename... Types>
126static inline void pop(Stack& stack, Types&... args) {
127 size_t i = 0;
128 constexpr size_t N = sizeof...(args);
129 (void)std::initializer_list<int>{
130 (args = std::move(peek(stack, i++, N)).template to<Types>(), 0)...};
131 drop(stack, N);
132}
133template <typename... Types>
134static inline void pop(Stack* stack, Types&... args) {
135 pop(*stack, args...);
136}
137template <typename Type>
138static inline void push_one(Stack& stack, Type&& arg) {
139 stack.emplace_back(std::forward<Type>(arg));
140}
141
142static inline void push_one(Stack& stack, c10::TensorOptions options) {
143 stack.emplace_back(c10::typeMetaToScalarType(options.dtype()));
144 stack.emplace_back(options.layout());
145 stack.emplace_back(options.device());
146 stack.emplace_back(options.pinned_memory());
147}
148
149template <typename... Types>
150static inline void push(Stack& stack, Types&&... args) {
151 (void)std::initializer_list<int>{(push_one(stack, std::forward<Types>(args)), 0)...};
152}
153template <typename... Types>
154static inline void push(Stack* stack, Types&&... args) {
155 return push(*stack, std::forward<Types>(args)...);
156}
157template <class T>
158static inline void push_list_elements(Stack& stack, const c10::List<T>& elements) {
159 for (T elem : elements) {
160 stack.push_back(std::move(elem));
161 }
162}
163
164// The packer here is carefully written not to make any unnecessary
165// copies.
166
167// pack takes the return values of aten functions pushes them onto the stack
168template <typename T>
169inline void pack(Stack& stack, T&& v) {
170 stack.emplace_back(std::forward<T>(v));
171}
172template <typename T>
173inline void pack(Stack* stack, T&& v) {
174 pack(*stack, std::forward<T>(v));
175}
176
177template <std::size_t remaining, typename... Args>
178struct TuplePacker {
179 // NB: *Not* a universal reference.
180 static void execute(Stack& stack, std::tuple<Args...>&& t) {
181 // NB: The move here does not "destroy" the entire tuple, that is
182 // not what std::move does; only the particular tuple index
183 // processed here gets stolen.
184 pack(stack, std::get<sizeof...(Args) - remaining>(std::move(t)));
185 TuplePacker<remaining - 1, Args...>::execute(stack, std::move(t));
186 }
187};
188
189template <typename... Args>
190struct TuplePacker<0, Args...> {
191 static void execute(Stack& /*stack*/, std::tuple<Args...>&& /*t*/){};
192};
193
194template <typename... Args>
195inline void pack(Stack& stack, std::tuple<Args...>&& t) {
196 TuplePacker<sizeof...(Args), Args...>::execute(stack, std::move(t));
197}
198
199} // namespace jit
200} // namespace torch
201