1 | #pragma once |
2 | |
3 | #include <type_traits> |
4 | |
5 | #include <ATen/core/ivalue.h> |
6 | #include <c10/util/Deprecated.h> |
7 | #include <c10/util/irange.h> |
8 | |
9 | // TODO move this to c10 namespace |
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | |
14 | using c10::IValue; |
15 | using Stack = std::vector<IValue>; |
16 | |
// Type-erased wrapper around a JIT operator implementation.
//
// An Operation holds a callable with signature void(Stack&): the callable
// pops its inputs off the stack and pushes its outputs back on. The class
// exists (rather than using std::function directly) so that the legacy
// void(Stack*) calling convention can still be accepted — with a
// deprecation warning — and adapted to the new void(Stack&) form.
class Operation {
  // accepts<F, Arg> is true iff F can be stored in a std::function<void(Arg)>,
  // i.e. F is callable with an argument of type Arg.
  template <typename F, typename Arg>
  using accepts = std::is_constructible<std::function<void(Arg)>, F&&>;

 public:
  // Deprecated: adapt a legacy callable taking Stack* by wrapping it in a
  // lambda that passes the address of the Stack&. Selected by SFINAE only
  // when F is callable with Stack*.
  template <typename F,
      std::enable_if_t<accepts<F, Stack*>::value, int> = 0>
  C10_DEPRECATED_MESSAGE("Please use void(Stack&) to register operator instead." )
  Operation(F&& raw): op_([raw = std::forward<F>(raw)](Stack& stack) {
    raw(&stack);
  }) {}

  // Preferred constructor: store a callable taking Stack&. The is_same
  // check keeps this templated overload from hijacking Operation's own
  // copy/move constructors.
  template <typename F,
      std::enable_if_t<accepts<F, Stack&>::value &&
          !std::is_same<std::decay_t<F>, Operation>::value, int> = 0>
  Operation(F&& op): op_(std::forward<F>(op)) {}

  // Construct an empty Operation (operator bool() will return false).
  Operation(std::nullptr_t) noexcept {}

  // True iff a callable is stored.
  explicit operator bool() const noexcept {
    return op_ ? true : false;
  }

  // Invoke the stored callable on the given stack.
  void operator()(Stack& stack) {
    op_(stack);
  }

  // Mirrors std::function::target: returns a pointer to the stored callable
  // if its type is exactly T, otherwise nullptr.
  template <typename T>
  T* target() noexcept {
    return op_.target<T>();
  }

 private:
  std::function<void(Stack&)> op_;
};
52 | |
53 | // An operation with N inputs and M outputs pops the last N inputs off |
// the stack and pushes its M outputs onto the stack
55 | // before: <other stack items> I0, I1, ... IN <- stack.back() |
56 | // after: <other stack items> O0, O1, ... OM |
57 | // operations are defined this way so that ownership of inputs can be |
58 | // transferred to the operation and it can incrementally drop ownership of |
59 | // tensors when they become unneeded. For large operations, like 'run an entire |
// subgraph', this functionality is very important for minimizing gpu memory
// usage.
// The return value is the relative 'offset' to jump to for the next
// operation:
63 | // pc += 1 + offset |
64 | // so a return value of 0 goes to the next instruction |
65 | |
66 | // treat the last N elements of the stack as a list, looking up |
67 | // element i |
68 | static inline IValue& peek(Stack& stack, size_t i, size_t N) { |
69 | return *(stack.end() - N + i); |
70 | } |
71 | static inline IValue& peek(Stack* stack, size_t i, size_t N) { |
72 | return peek(*stack, i, N); |
73 | } |
74 | static inline const IValue& peek(const Stack& stack, size_t i, size_t N) { |
75 | return *(stack.end() - N + i); |
76 | } |
77 | static inline const IValue& peek(const Stack* stack, size_t i, size_t N) { |
78 | return peek(*stack, i, N); |
79 | } |
80 | // treat the last N elements of the stack as a list, looking up the |
81 | // slice starting at index i and having length len |
82 | static inline at::ArrayRef<IValue> peekSlice( |
83 | const Stack& stack, |
84 | size_t i, |
85 | size_t len, |
86 | size_t N) { |
87 | return at::ArrayRef<IValue>(stack).slice(stack.size() - N + i, len); |
88 | } |
89 | static inline at::ArrayRef<IValue> last(const Stack& stack, size_t N) { |
90 | return peekSlice(stack, 0, N, N); |
91 | } |
92 | static inline at::ArrayRef<IValue> last(const Stack* stack, size_t N) { |
93 | return last(*stack, N); |
94 | } |
95 | static inline void drop(Stack& stack, size_t n) { |
96 | stack.erase(stack.end() - n, stack.end()); |
97 | } |
98 | static inline void drop(Stack* stack, size_t n) { |
99 | drop(*stack, n); |
100 | } |
101 | static inline IValue pop(Stack& stack) { |
102 | auto r = std::move(stack.back()); |
103 | stack.pop_back(); |
104 | return r; |
105 | } |
106 | static inline IValue pop(Stack* stack) { |
107 | return pop(*stack); |
108 | } |
109 | static inline std::vector<IValue> pop(Stack& stack, size_t n) { |
110 | std::vector<IValue> result; |
111 | result.reserve(n); |
112 | for (const auto i : c10::irange(n)) { |
113 | result.push_back(std::move(peek(stack, i, n))); |
114 | } |
115 | drop(stack, n); |
116 | return result; |
117 | } |
118 | |
119 | // variadic pop: |
120 | // int64_t a; at::Tensor b; |
121 | // pop(stack, a, b); |
122 | // equivalent to: |
123 | // b = pop(stack).toTensor(); |
124 | // a = pop(stack).toInt(); |
// Pop the top sizeof...(args) elements into the given output references,
// converting each IValue with IValue::to<Type>(). Arguments receive values
// in stack order: the first arg gets the deepest of the popped elements.
template <typename... Types>
static inline void pop(Stack& stack, Types&... args) {
  size_t i = 0;
  constexpr size_t N = sizeof...(args);
  // The braced initializer_list guarantees left-to-right evaluation of the
  // pack expansion, so i++ walks the peeked elements in order; the (expr, 0)
  // comma trick produces the dummy ints that fill the list.
  (void)std::initializer_list<int>{
      (args = std::move(peek(stack, i++, N)).template to<Types>(), 0)...};
  drop(stack, N);
}
133 | template <typename... Types> |
134 | static inline void pop(Stack* stack, Types&... args) { |
135 | pop(*stack, args...); |
136 | } |
137 | template <typename Type> |
138 | static inline void push_one(Stack& stack, Type&& arg) { |
139 | stack.emplace_back(std::forward<Type>(arg)); |
140 | } |
141 | |
142 | static inline void push_one(Stack& stack, c10::TensorOptions options) { |
143 | stack.emplace_back(c10::typeMetaToScalarType(options.dtype())); |
144 | stack.emplace_back(options.layout()); |
145 | stack.emplace_back(options.device()); |
146 | stack.emplace_back(options.pinned_memory()); |
147 | } |
148 | |
149 | template <typename... Types> |
150 | static inline void push(Stack& stack, Types&&... args) { |
151 | (void)std::initializer_list<int>{(push_one(stack, std::forward<Types>(args)), 0)...}; |
152 | } |
153 | template <typename... Types> |
154 | static inline void push(Stack* stack, Types&&... args) { |
155 | return push(*stack, std::forward<Types>(args)...); |
156 | } |
157 | template <class T> |
158 | static inline void push_list_elements(Stack& stack, const c10::List<T>& elements) { |
159 | for (T elem : elements) { |
160 | stack.push_back(std::move(elem)); |
161 | } |
162 | } |
163 | |
164 | // The packer here is carefully written not to make any unnecessary |
165 | // copies. |
166 | |
167 | // pack takes the return values of aten functions pushes them onto the stack |
168 | template <typename T> |
169 | inline void pack(Stack& stack, T&& v) { |
170 | stack.emplace_back(std::forward<T>(v)); |
171 | } |
172 | template <typename T> |
173 | inline void pack(Stack* stack, T&& v) { |
174 | pack(*stack, std::forward<T>(v)); |
175 | } |
176 | |
// Recursive helper that packs the elements of a tuple onto the stack in
// order, stealing (moving) each element one at a time. `remaining` counts
// how many elements are still to be packed; the element processed at each
// step is index sizeof...(Args) - remaining.
template <std::size_t remaining, typename... Args>
struct TuplePacker {
  // NB: *Not* a universal reference.
  static void execute(Stack& stack, std::tuple<Args...>&& t) {
    // NB: The move here does not "destroy" the entire tuple, that is
    // not what std::move does; only the particular tuple index
    // processed here gets stolen.
    pack(stack, std::get<sizeof...(Args) - remaining>(std::move(t)));
    TuplePacker<remaining - 1, Args...>::execute(stack, std::move(t));
  }
};
188 | |
189 | template <typename... Args> |
190 | struct TuplePacker<0, Args...> { |
191 | static void execute(Stack& /*stack*/, std::tuple<Args...>&& /*t*/){}; |
192 | }; |
193 | |
194 | template <typename... Args> |
195 | inline void pack(Stack& stack, std::tuple<Args...>&& t) { |
196 | TuplePacker<sizeof...(Args), Args...>::execute(stack, std::move(t)); |
197 | } |
198 | |
199 | } // namespace jit |
200 | } // namespace torch |
201 | |