1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/transforms/device_aware_visitors.h
22 * \brief Visitors which track the device for the current Relay expression and Relay Vars.
23 */
24
25#ifndef TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_
26#define TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_
27
28#include <dlpack/dlpack.h>
29#include <tvm/relay/expr.h>
30#include <tvm/relay/expr_functor.h>
31#include <tvm/relay/function.h>
32
33#include <unordered_map>
34#include <utility>
35#include <vector>
36
37#include "../op/annotation/annotation.h"
38#include "../op/memory/on_device.h"
39
40namespace tvm {
41namespace relay {
42namespace transform {
43
44/*!
45 * \brief Helper class for expression transformers which need to keep track of the \p VirtualDevice
46 * holding the results of expressions. This is recovered from function attributes and "on_device"
47 * CallNodes added by the PlanDevices pass.
48 *
49 * \sa \p DeviceAwareExpr{Functor,Visitor,Mutator}.
50 */
51class LexicalOnDeviceMixin {
52 protected:
53 explicit LexicalOnDeviceMixin(const Optional<IRModule>& maybe_mod);
54
55 /*!
56 * \brief Returns the \p VirtualDevice on which the result of \p expr should/will be stored,
57 * assuming {Push,Pop}{VirtualDevice,BoundVar} have been correctly called. May return the
58 * unconstrained \p VirtualDevice if the device planning pass has not been run.
59 */
60 VirtualDevice GetVirtualDevice(const Expr& expr) const;
61
62 /*! \brief Indicate a function body is being entered. */
63 void EnterFunctionBody();
64
65 /*! \brief Indicate a function body has been processed. */
66 void ExitFunctionBody();
67
68 /*! \brief Push an \p VirtualDevice onto the lexical VirtualDevice stack. Ignore if unconstrained.
69 */
70 void PushVirtualDevice(const VirtualDevice& virtual_device);
71
72 /*! \brief Pop an \p VirtualDevice from the lexical VirtualDevice stack. Ignore if stack is empty.
73 */
74 void PopVirtualDevice();
75
76 /*! \brief Remember that \p var will be stored at \p virtual_device. Ignore if unconstrained.
77 *
78 * CAUTION: Despite the name we don't support re-entering the same function body.
79 */
80 void PushBoundVar(Var var, const VirtualDevice& virtual_device);
81
82 /*! \brief Remove the binding for \p var to its \p VirtualDevice. Ignore if var is not bound. */
83 void PopBoundVar(const Var& var);
84
85 /*!
86 * \brief Returns the number of function definitions wrapping the currently visited expression.
87 */
88 int function_nesting() const { return function_nesting_; }
89
90 private:
91 /*!
92 * \brief The number of function bodies entered. Since many transforms need to distinguish global
93 * functions from local functions this supports the mixin's \p is_global() helper method.
94 */
95 int function_nesting_ = 0;
96
97 /*!
98 * \brief The stack of lexically enclosing "on_device" \p VirtualDevices, from outermost to
99 * innermost. When visiting an expression other than a variable we can assume the expression's
100 * result is to be stored on \p expr_virtual_devices.back().
101 */
102 std::vector<VirtualDevice> expr_virtual_devices_;
103
104 /*!
105 * \brief A map from in-scope local variables to their \p VirtualDevices. We may assume the
106 * variable is only ever bound to a value stored on this \p VirtualDevice at runtime.
107 *
108 * Note: We're playing it safe and keying by object refs here just in case the Relay expression
109 * being rewritten has no module or other global to keep it alive.
110 */
111 std::unordered_map<Var, VirtualDevice, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>
112 var_virtual_devices_;
113
114 /*!
115 * \brief A map from global variables to their \p VirtualDevices, ie the "result_virtual_device"
116 * of the function they are bound to in the module we are working on. We calculate and store this
117 * explicitly so that we don't need to hold on to any module, which is often in the process of
118 * being rewritten.
119 */
120 std::unordered_map<GlobalVar, VirtualDevice, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>
121 global_var_virtual_devices_;
122};
123
124template <typename FType>
125class DeviceAwareExprFunctor;
126
127/*!
128 * \brief ExprFunctor which tracks \p VirtualDevices. We only support 'visitor' style implementation
129 * with no additional arguments, thus this is equivalent to \p DeviceAwareExprVisitor without
130 * any memoization.
131 */
132template <>
133class DeviceAwareExprFunctor<void(const Expr& n)> : public ExprFunctor<void(const Expr& n)>,
134 public LexicalOnDeviceMixin {
135 private:
136 using TSuper = ExprFunctor<void(const Expr& n)>;
137
138 public:
139 explicit DeviceAwareExprFunctor(const Optional<IRModule>& maybe_mod)
140 : LexicalOnDeviceMixin(maybe_mod) {}
141
142 void VisitExpr_(const FunctionNode* function_node) {
143 if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
144 // No tracking inside primitive functions.
145 return DeviceAwareVisitExpr_(function_node);
146 } else {
147 // Function parameters come into scope.
148 for (auto param : function_node->params) {
149 PushBoundVar(param, param->virtual_device());
150 }
151 // Entering scope of function body.
152 VirtualDevice virtual_device = function_node->virtual_device();
153 VLOG(2) << "entering " << virtual_device << " for function:" << std::endl
154 << PrettyPrint(GetRef<Function>(function_node));
155 PushVirtualDevice(virtual_device);
156 EnterFunctionBody();
157
158 DeviceAwareVisitExpr_(function_node);
159
160 // Leaving scope of function body.
161 ExitFunctionBody();
162 PopVirtualDevice();
163 VLOG(2) << "leaving " << virtual_device << " for function:" << std::endl
164 << PrettyPrint(GetRef<Function>(function_node));
165 // Function parameters go out of scope.
166 for (size_t i = 0; i < function_node->params.size(); ++i) {
167 PopBoundVar(function_node->params[i]);
168 }
169 }
170 }
171
172 void VisitExpr_(const LetNode* let_node) {
173 PreVisitLetBlock_(let_node);
174 std::vector<const LetNode*> bindings;
175 Expr expr = GetRef<Expr>(let_node);
176 while (const auto* inner_let_node = expr.as<LetNode>()) {
177 // Let-bound var (in pre visited version) goes into scope.
178 // (We'll just assume this is a letrec.)
179 VirtualDevice virtual_device = GetVirtualDevice(inner_let_node->value);
180 VLOG(2) << "var '" << inner_let_node->var->name_hint() << "' has virtual device "
181 << virtual_device;
182 PushBoundVar(inner_let_node->var, virtual_device);
183 PreVisitLetBinding_(inner_let_node->var, inner_let_node->value);
184 bindings.emplace_back(inner_let_node);
185 expr = inner_let_node->body;
186 }
187
188 VisitExpr(expr);
189
190 for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) {
191 // Let-bound var goes out of scope.
192 const LetNode* visited_let_node = *itr;
193 PopBoundVar(visited_let_node->var);
194 PostVisitLet_(visited_let_node);
195 }
196 PostVisitLetBlock_(let_node);
197 }
198
199 void VisitExpr_(const CallNode* call_node) {
200 OnDeviceProps props = GetOnDeviceProps(call_node);
201 if (props.body.defined() && props.is_fixed()) {
202 // Entering lexical scope of "on_device" call.
203 VLOG(2) << "entering " << props.virtual_device << " for on_device:" << std::endl
204 << PrettyPrint(GetRef<Call>(call_node));
205 PushVirtualDevice(props.virtual_device);
206 VisitExpr(props.body);
207 // Leaving lexical scope of "on_device" call.
208 PopVirtualDevice();
209 VLOG(2) << "leaving " << props.virtual_device << " for on_device:" << std::endl
210 << PrettyPrint(GetRef<Call>(call_node));
211 } else {
212 DeviceAwareVisitExpr_(call_node);
213 }
214 }
215
216 /*!
217 * \brief These are as for VisitExpr_. \p VirtualDevices for expressions and function parameters
218 * will be tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For
219 * functions the function_nesting count will already include that of \p function_node.
220 */
221
222 virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node) {
223 return TSuper::VisitExpr_(function_node);
224 }
225
226 virtual void DeviceAwareVisitExpr_(const CallNode* call_node) {
227 return TSuper::VisitExpr_(call_node);
228 }
229
230 /*!
231 * \brief Visit the first let in a chain of let expressions before any let bindings or final
232 * body has been visited. Default implementation is a no-op.
233 */
234 virtual void PreVisitLetBlock_(const LetNode* let_node) { /* no-op */
235 }
236
237 /*!
238 * \brief Visit a let-bound expression before the let body has been visited. Devices for the
239 * let-bound variable will be tracked automatically. Default implementation just visits var and
240 * value.
241 */
242 virtual void PreVisitLetBinding_(const Var& var, const Expr& value) {
243 VisitExpr(var);
244 VisitExpr(value);
245 }
246
247 /*!
248 * \brief Visit a let expression after the let-bound value and body have been visited.
249 * Default implementation is a no-op.
250 */
251 virtual void PostVisitLet_(const LetNode* let_node) { /* no-op */
252 }
253
254 /*!
255 * \brief Visit the first let in a chain of let expressions after it has been visited.
256 * Default implementation is a no-op.
257 */
258 virtual void PostVisitLetBlock_(const LetNode* let_node) {}
259};
260
261/*! \brief ExprVisitor which tracks \p VirtualDevices. */
262class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin {
263 public:
264 explicit DeviceAwareExprVisitor(const Optional<IRModule>& maybe_mod)
265 : LexicalOnDeviceMixin(maybe_mod) {}
266
267 using ExprVisitor::VisitExpr_;
268
269 void VisitExpr_(const FunctionNode* function_node) final;
270 void VisitExpr_(const LetNode* let_node) final;
271 void VisitExpr_(const CallNode* call_node) final;
272
273 /*!
274 * \brief These are as for VisitExpr_. \p VirtualDevices for expressions and function parameters
275 * will be tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For
276 * functions the function_nesting count will already include that of \p function_node.
277 */
278 virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node);
279 virtual void DeviceAwareVisitExpr_(const CallNode* call_node);
280
281 /*!
282 * \brief Visit the first let in a chain of let expressions before any let bindings or final
283 * body has been visited. Default implementation is a no-op.
284 */
285 virtual void PreVisitLetBlock_(const LetNode* let_node);
286
287 /*!
288 * \brief Visit a let-bound expression before the let body has been visited. \p VirtualDevices for
289 * the let-bound variable will be tracked automatically. Default implementation just visits var
290 * and value.
291 */
292 virtual void PreVisitLetBinding_(const Var& var, const Expr& value);
293
294 /*!
295 * \brief Visit a let expression after the let-bound value and body have been visited.
296 * Default implementation is a no-op.
297 */
298 virtual void PostVisitLet_(const LetNode* let_node);
299
300 /*!
301 * \brief Visit the first let in a chain of let expressions after it has been visited.
302 * Default implementation is a no-op.
303 */
304 virtual void PostVisitLetBlock_(const LetNode* let_node);
305};
306
307/*! \brief ExprMutator which tracks \p VirtualDevices. */
308class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin {
309 public:
310 explicit DeviceAwareExprMutator(const Optional<IRModule>& maybe_mod)
311 : LexicalOnDeviceMixin(maybe_mod) {}
312
313 Expr VisitExpr_(const FunctionNode* function_node) final;
314 Expr VisitExpr_(const LetNode* let_node) final;
315 Expr VisitExpr_(const CallNode* call_node) final;
316
317 /*!
318 * \brief These are as for VisitExpr_. \p VirtualDevices for expressions and function parameters
319 * will be tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For
320 * functions the function_nesting count will already include that of \p function_node.
321 */
322 virtual Expr DeviceAwareVisitExpr_(const FunctionNode* function_node);
323 virtual Expr DeviceAwareVisitExpr_(const CallNode* call_node);
324
325 /*!
326 * \brief Visit the first let in a chain of let expressions before any let bindings or final
327 * body has been visited. Default implementation is a no-op.
328 */
329 virtual void PreVisitLetBlock_(const LetNode* let_node);
330
331 /*!
332 * \brief Visit a let-bound expression before the let body has been visited. \p VirtualDevices for
333 * the let-bound variable will be tracked automatically. Default implementation just visits var
334 * and value.
335 */
336 virtual std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value);
337
338 /*!
339 * \brief Visit a let expression after the let-bound value and body have been visited.
340 * Default implementation just returns a reference to the post-visited node.
341 */
342 virtual Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node);
343
344 /*!
345 * \brief Visit the first let in a chain of let expressions after it has been visited.
346 * Default implementation returns reference to let node.
347 */
348 virtual Expr PostVisitLetBlock_(const LetNode* pre_let_node, const LetNode* post_let_node);
349};
350
351/*!
352 * \brief Returs a map from Relay expression node to its virtual device using the annotations
353 * and \p virtual_device fields of \p expr. The map's lifetime must not exceed that of
354 * \p expr itself.
355 */
356std::unordered_map<const ExprNode*, VirtualDevice> RecoverVirtualDeviceMap(const IRModule& mod,
357 const Expr& expr);
358
359} // namespace transform
360} // namespace relay
361} // namespace tvm
362
363#endif // TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_
364