1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/transforms/device_aware_visitors.h |
22 | * \brief Visitors which track the device for the current Relay expression and Relay Vars. |
23 | */ |
24 | |
25 | #ifndef TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_ |
26 | #define TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_ |
27 | |
28 | #include <dlpack/dlpack.h> |
29 | #include <tvm/relay/expr.h> |
30 | #include <tvm/relay/expr_functor.h> |
31 | #include <tvm/relay/function.h> |
32 | |
33 | #include <unordered_map> |
34 | #include <utility> |
35 | #include <vector> |
36 | |
37 | #include "../op/annotation/annotation.h" |
38 | #include "../op/memory/on_device.h" |
39 | |
40 | namespace tvm { |
41 | namespace relay { |
42 | namespace transform { |
43 | |
44 | /*! |
45 | * \brief Helper class for expression transformers which need to keep track of the \p VirtualDevice |
46 | * holding the results of expressions. This is recovered from function attributes and "on_device" |
47 | * CallNodes added by the PlanDevices pass. |
48 | * |
49 | * \sa \p DeviceAwareExpr{Functor,Visitor,Mutator}. |
50 | */ |
class LexicalOnDeviceMixin {
 protected:
  /*!
   * \brief Constructs the mixin. The constructor is protected, so it is only reachable from
   * derived classes.
   * \param maybe_mod Optional module. NOTE(review): presumably used to populate the global var
   * to \p VirtualDevice map kept by this class; confirm against the out-of-line definition.
   */
  explicit LexicalOnDeviceMixin(const Optional<IRModule>& maybe_mod);

  /*!
   * \brief Returns the \p VirtualDevice on which the result of \p expr should/will be stored,
   * assuming {Push,Pop}{VirtualDevice,BoundVar} have been correctly called. May return the
   * unconstrained \p VirtualDevice if the device planning pass has not been run.
   */
  VirtualDevice GetVirtualDevice(const Expr& expr) const;

  /*! \brief Indicate a function body is being entered. */
  void EnterFunctionBody();

  /*! \brief Indicate a function body has been processed. */
  void ExitFunctionBody();

  /*! \brief Push a \p VirtualDevice onto the lexical VirtualDevice stack. Ignore if unconstrained.
   */
  void PushVirtualDevice(const VirtualDevice& virtual_device);

  /*! \brief Pop a \p VirtualDevice from the lexical VirtualDevice stack. Ignore if stack is empty.
   */
  void PopVirtualDevice();

  /*! \brief Remember that \p var will be stored at \p virtual_device. Ignore if unconstrained.
   *
   * CAUTION: Despite the name we don't support re-entering the same function body.
   */
  void PushBoundVar(Var var, const VirtualDevice& virtual_device);

  /*! \brief Remove the binding for \p var to its \p VirtualDevice. Ignore if var is not bound. */
  void PopBoundVar(const Var& var);

  /*!
   * \brief Returns the number of function definitions wrapping the currently visited expression.
   */
  int function_nesting() const { return function_nesting_; }

 private:
  /*!
   * \brief The number of function bodies entered. Many transforms need to distinguish global
   * functions from local functions; the count is exposed to them via \p function_nesting().
   */
  int function_nesting_ = 0;

  /*!
   * \brief The stack of lexically enclosing "on_device" \p VirtualDevices, from outermost to
   * innermost. When visiting an expression other than a variable we can assume the expression's
   * result is to be stored on \p expr_virtual_devices_.back().
   */
  std::vector<VirtualDevice> expr_virtual_devices_;

  /*!
   * \brief A map from in-scope local variables to their \p VirtualDevices. We may assume the
   * variable is only ever bound to a value stored on this \p VirtualDevice at runtime.
   *
   * Note: We're playing it safe and keying by object refs here just in case the Relay expression
   * being rewritten has no module or other global to keep it alive.
   */
  std::unordered_map<Var, VirtualDevice, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>
      var_virtual_devices_;

  /*!
   * \brief A map from global variables to their \p VirtualDevices, ie the "result_virtual_device"
   * of the function they are bound to in the module we are working on. We calculate and store this
   * explicitly so that we don't need to hold on to any module, which is often in the process of
   * being rewritten.
   */
  std::unordered_map<GlobalVar, VirtualDevice, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>
      global_var_virtual_devices_;
};
123 | |
124 | template <typename FType> |
125 | class DeviceAwareExprFunctor; |
126 | |
127 | /*! |
128 | * \brief ExprFunctor which tracks \p VirtualDevices. We only support 'visitor' style implementation |
129 | * with no additional arguments, thus this is equivalent to \p DeviceAwareExprVisitor without |
130 | * any memoization. |
131 | */ |
132 | template <> |
133 | class DeviceAwareExprFunctor<void(const Expr& n)> : public ExprFunctor<void(const Expr& n)>, |
134 | public LexicalOnDeviceMixin { |
135 | private: |
136 | using TSuper = ExprFunctor<void(const Expr& n)>; |
137 | |
138 | public: |
139 | explicit DeviceAwareExprFunctor(const Optional<IRModule>& maybe_mod) |
140 | : LexicalOnDeviceMixin(maybe_mod) {} |
141 | |
142 | void VisitExpr_(const FunctionNode* function_node) { |
143 | if (function_node->HasNonzeroAttr(attr::kPrimitive)) { |
144 | // No tracking inside primitive functions. |
145 | return DeviceAwareVisitExpr_(function_node); |
146 | } else { |
147 | // Function parameters come into scope. |
148 | for (auto param : function_node->params) { |
149 | PushBoundVar(param, param->virtual_device()); |
150 | } |
151 | // Entering scope of function body. |
152 | VirtualDevice virtual_device = function_node->virtual_device(); |
153 | VLOG(2) << "entering " << virtual_device << " for function:" << std::endl |
154 | << PrettyPrint(GetRef<Function>(function_node)); |
155 | PushVirtualDevice(virtual_device); |
156 | EnterFunctionBody(); |
157 | |
158 | DeviceAwareVisitExpr_(function_node); |
159 | |
160 | // Leaving scope of function body. |
161 | ExitFunctionBody(); |
162 | PopVirtualDevice(); |
163 | VLOG(2) << "leaving " << virtual_device << " for function:" << std::endl |
164 | << PrettyPrint(GetRef<Function>(function_node)); |
165 | // Function parameters go out of scope. |
166 | for (size_t i = 0; i < function_node->params.size(); ++i) { |
167 | PopBoundVar(function_node->params[i]); |
168 | } |
169 | } |
170 | } |
171 | |
172 | void VisitExpr_(const LetNode* let_node) { |
173 | PreVisitLetBlock_(let_node); |
174 | std::vector<const LetNode*> bindings; |
175 | Expr expr = GetRef<Expr>(let_node); |
176 | while (const auto* inner_let_node = expr.as<LetNode>()) { |
177 | // Let-bound var (in pre visited version) goes into scope. |
178 | // (We'll just assume this is a letrec.) |
179 | VirtualDevice virtual_device = GetVirtualDevice(inner_let_node->value); |
180 | VLOG(2) << "var '" << inner_let_node->var->name_hint() << "' has virtual device " |
181 | << virtual_device; |
182 | PushBoundVar(inner_let_node->var, virtual_device); |
183 | PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); |
184 | bindings.emplace_back(inner_let_node); |
185 | expr = inner_let_node->body; |
186 | } |
187 | |
188 | VisitExpr(expr); |
189 | |
190 | for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { |
191 | // Let-bound var goes out of scope. |
192 | const LetNode* visited_let_node = *itr; |
193 | PopBoundVar(visited_let_node->var); |
194 | PostVisitLet_(visited_let_node); |
195 | } |
196 | PostVisitLetBlock_(let_node); |
197 | } |
198 | |
199 | void VisitExpr_(const CallNode* call_node) { |
200 | OnDeviceProps props = GetOnDeviceProps(call_node); |
201 | if (props.body.defined() && props.is_fixed()) { |
202 | // Entering lexical scope of "on_device" call. |
203 | VLOG(2) << "entering " << props.virtual_device << " for on_device:" << std::endl |
204 | << PrettyPrint(GetRef<Call>(call_node)); |
205 | PushVirtualDevice(props.virtual_device); |
206 | VisitExpr(props.body); |
207 | // Leaving lexical scope of "on_device" call. |
208 | PopVirtualDevice(); |
209 | VLOG(2) << "leaving " << props.virtual_device << " for on_device:" << std::endl |
210 | << PrettyPrint(GetRef<Call>(call_node)); |
211 | } else { |
212 | DeviceAwareVisitExpr_(call_node); |
213 | } |
214 | } |
215 | |
216 | /*! |
217 | * \brief These are as for VisitExpr_. \p VirtualDevices for expressions and function parameters |
218 | * will be tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For |
219 | * functions the function_nesting count will already include that of \p function_node. |
220 | */ |
221 | |
222 | virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node) { |
223 | return TSuper::VisitExpr_(function_node); |
224 | } |
225 | |
226 | virtual void DeviceAwareVisitExpr_(const CallNode* call_node) { |
227 | return TSuper::VisitExpr_(call_node); |
228 | } |
229 | |
230 | /*! |
231 | * \brief Visit the first let in a chain of let expressions before any let bindings or final |
232 | * body has been visited. Default implementation is a no-op. |
233 | */ |
234 | virtual void PreVisitLetBlock_(const LetNode* let_node) { /* no-op */ |
235 | } |
236 | |
237 | /*! |
238 | * \brief Visit a let-bound expression before the let body has been visited. Devices for the |
239 | * let-bound variable will be tracked automatically. Default implementation just visits var and |
240 | * value. |
241 | */ |
242 | virtual void PreVisitLetBinding_(const Var& var, const Expr& value) { |
243 | VisitExpr(var); |
244 | VisitExpr(value); |
245 | } |
246 | |
247 | /*! |
248 | * \brief Visit a let expression after the let-bound value and body have been visited. |
249 | * Default implementation is a no-op. |
250 | */ |
251 | virtual void PostVisitLet_(const LetNode* let_node) { /* no-op */ |
252 | } |
253 | |
254 | /*! |
255 | * \brief Visit the first let in a chain of let expressions after it has been visited. |
256 | * Default implementation is a no-op. |
257 | */ |
258 | virtual void PostVisitLetBlock_(const LetNode* let_node) {} |
259 | }; |
260 | |
261 | /*! \brief ExprVisitor which tracks \p VirtualDevices. */ |
class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin {
 public:
  explicit DeviceAwareExprVisitor(const Optional<IRModule>& maybe_mod)
      : LexicalOnDeviceMixin(maybe_mod) {}

  // Keep the remaining ExprVisitor::VisitExpr_ overloads visible; the overloads declared below
  // would otherwise hide them.
  using ExprVisitor::VisitExpr_;

  // These overrides are final: sub-classes customize behaviour via the DeviceAwareVisitExpr_ and
  // {Pre,Post}VisitLet* hooks declared below instead.
  void VisitExpr_(const FunctionNode* function_node) final;
  void VisitExpr_(const LetNode* let_node) final;
  void VisitExpr_(const CallNode* call_node) final;

  /*!
   * \brief These are as for VisitExpr_. \p VirtualDevices for expressions and function parameters
   * will be tracked automatically. For functions the function_nesting count will already include
   * that of \p function_node.
   *
   * NOTE(review): the default implementations are defined out of line; presumably they defer to
   * ExprVisitor::VisitExpr_ (the base class here is ExprVisitor, not ExprMutator) -- confirm.
   */
  virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node);
  virtual void DeviceAwareVisitExpr_(const CallNode* call_node);

  /*!
   * \brief Visit the first let in a chain of let expressions before any let bindings or final
   * body has been visited. Default implementation is a no-op.
   */
  virtual void PreVisitLetBlock_(const LetNode* let_node);

  /*!
   * \brief Visit a let-bound expression before the let body has been visited. \p VirtualDevices for
   * the let-bound variable will be tracked automatically. Default implementation just visits var
   * and value.
   */
  virtual void PreVisitLetBinding_(const Var& var, const Expr& value);

  /*!
   * \brief Visit a let expression after the let-bound value and body have been visited.
   * Default implementation is a no-op.
   */
  virtual void PostVisitLet_(const LetNode* let_node);

  /*!
   * \brief Visit the first let in a chain of let expressions after it has been visited.
   * Default implementation is a no-op.
   */
  virtual void PostVisitLetBlock_(const LetNode* let_node);
};
306 | |
307 | /*! \brief ExprMutator which tracks \p VirtualDevices. */ |
class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin {
 public:
  explicit DeviceAwareExprMutator(const Optional<IRModule>& maybe_mod)
      : LexicalOnDeviceMixin(maybe_mod) {}

  // These overrides are final: sub-classes customize behaviour via the DeviceAwareVisitExpr_ and
  // {Pre,Post}VisitLet* hooks declared below instead.
  Expr VisitExpr_(const FunctionNode* function_node) final;
  Expr VisitExpr_(const LetNode* let_node) final;
  Expr VisitExpr_(const CallNode* call_node) final;

  /*!
   * \brief These are as for VisitExpr_. \p VirtualDevices for expressions and function parameters
   * will be tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For
   * functions the function_nesting count will already include that of \p function_node.
   */
  virtual Expr DeviceAwareVisitExpr_(const FunctionNode* function_node);
  virtual Expr DeviceAwareVisitExpr_(const CallNode* call_node);

  /*!
   * \brief Visit the first let in a chain of let expressions before any let bindings or final
   * body has been visited. Default implementation is a no-op.
   */
  virtual void PreVisitLetBlock_(const LetNode* let_node);

  /*!
   * \brief Visit a let-bound expression before the let body has been visited. \p VirtualDevices for
   * the let-bound variable will be tracked automatically. Default implementation just visits var
   * and value.
   *
   * \return A (var, value) pair. NOTE(review): presumably the rewritten variable and value to
   * use when rebuilding the let; confirm against the out-of-line definition.
   */
  virtual std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value);

  /*!
   * \brief Visit a let expression after the let-bound value and body have been visited.
   * Default implementation just returns a reference to the post-visited node.
   *
   * \param pre_let_node The let node before visiting.
   * \param post_let_node The let node after visiting.
   */
  virtual Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node);

  /*!
   * \brief Visit the first let in a chain of let expressions after it has been visited.
   * Default implementation returns reference to let node.
   *
   * \param pre_let_node The first let node of the chain before visiting.
   * \param post_let_node The first let node of the chain after visiting.
   */
  virtual Expr PostVisitLetBlock_(const LetNode* pre_let_node, const LetNode* post_let_node);
};
350 | |
/*!
 * \brief Returns a map from Relay expression node to its virtual device using the annotations
 * and \p virtual_device fields of \p expr. The map is keyed by raw \p ExprNode pointers, so its
 * lifetime must not exceed that of \p expr itself.
 *
 * NOTE(review): the role of \p mod is not visible from this header; confirm against the
 * definition.
 */
std::unordered_map<const ExprNode*, VirtualDevice> RecoverVirtualDeviceMap(const IRModule& mod,
                                                                           const Expr& expr);
358 | |
359 | } // namespace transform |
360 | } // namespace relay |
361 | } // namespace tvm |
362 | |
363 | #endif // TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_ |
364 | |