1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/transforms/device_aware_visitors.cc
22 * \brief Visitors which track the device for the current Relay expression.
23 */
24
25#include "./device_aware_visitors.h"
26
27namespace tvm {
28namespace relay {
29namespace transform {
30
31// TODO(mbs): This machinery can be used a) on expressions/modules which have not had
32// device planning run, and b) on expressions for which we've not kept track of their
// containing module. For now we'll handle b) by being as forgiving as possible when recovering
34// the device for an expression, and we'll support a) the same way. But better would be
35// to ICHECK fail when, eg, a variable is not in scope or the lexical device stack is empty.
36
37LexicalOnDeviceMixin::LexicalOnDeviceMixin(const Optional<IRModule>& maybe_mod) {
38 if (maybe_mod) {
39 for (const auto& kv : maybe_mod.value()->functions) {
40 if (const auto* function_node = kv.second.as<FunctionNode>()) {
41 VirtualDevice virtual_device = function_node->virtual_device();
42 if (!virtual_device->IsFullyUnconstrained()) {
43 VLOG(2) << "global '" << kv.first->name_hint << "' has virtual device " << virtual_device;
44 global_var_virtual_devices_.emplace(kv.first, virtual_device);
45 }
46 }
47 }
48 }
49}
50
51VirtualDevice LexicalOnDeviceMixin::GetVirtualDevice(const Expr& expr) const {
52 OnDeviceProps props = GetOnDeviceProps(expr);
53 if (props.body.defined() && props.is_fixed()) {
54 return props.virtual_device;
55 } else if (const auto* var_node = expr.as<VarNode>()) {
56 // Lookup variable binding.
57 auto itr = var_virtual_devices_.find(GetRef<Var>(var_node));
58 if (itr != var_virtual_devices_.end()) {
59 return itr->second;
60 }
61 // else: fallthrough to unconstrained
62 } else if (const auto* global_var_node = expr.as<GlobalVarNode>()) {
63 // Lookup global variable.
64 auto itr = global_var_virtual_devices_.find(GetRef<GlobalVar>(global_var_node));
65 if (itr != global_var_virtual_devices_.end()) {
66 return itr->second;
67 }
68 // else: fallthrough to unconstrained
69 } else if (const auto* function_node = expr.as<FunctionNode>()) {
70 if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
71 if (!expr_virtual_devices_.empty()) {
72 // Use the currently in-scope device type.
73 return expr_virtual_devices_.back();
74 }
75 // else: fallthrough to unconstrained
76 } else {
77 return function_node->virtual_device();
78 }
79 } else {
80 if (!expr_virtual_devices_.empty()) {
81 // Use the currently in-scope device type.
82 return expr_virtual_devices_.back();
83 }
84 // else: fallthrough to unconstrained
85 }
86 return VirtualDevice::FullyUnconstrained();
87}
88
89void LexicalOnDeviceMixin::EnterFunctionBody() { ++function_nesting_; }
90
91void LexicalOnDeviceMixin::ExitFunctionBody() {
92 ICHECK_GT(function_nesting_, 0);
93 --function_nesting_;
94}
95
96void LexicalOnDeviceMixin::PushVirtualDevice(const VirtualDevice& virtual_device) {
97 expr_virtual_devices_.emplace_back(virtual_device);
98}
99
100void LexicalOnDeviceMixin::PopVirtualDevice() {
101 if (expr_virtual_devices_.empty()) {
102 return;
103 }
104 expr_virtual_devices_.pop_back();
105}
106
107void LexicalOnDeviceMixin::PushBoundVar(Var var, const VirtualDevice& virtual_device) {
108 if (virtual_device->IsFullyUnconstrained()) {
109 return;
110 }
111 ICHECK(var_virtual_devices_.find(var) == var_virtual_devices_.end());
112 var_virtual_devices_.emplace(std::move(var), virtual_device);
113}
114
115void LexicalOnDeviceMixin::PopBoundVar(const Var& var) {
116 auto itr = var_virtual_devices_.find(var);
117 if (itr == var_virtual_devices_.end()) {
118 return;
119 }
120 var_virtual_devices_.erase(itr);
121}
122
123// TODO(mbs): We'd probably have less tedious code duplication if we redefined the memoizing
124// mutator on top of the generic Functor.
125
126void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) {
127 if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
128 // No tracking inside primitive functions.
129 DeviceAwareVisitExpr_(function_node);
130 } else {
131 // Function parameters come into scope.
132 for (auto param : function_node->params) {
133 PushBoundVar(param, param->virtual_device());
134 }
135 // Entering scope of function body.
136 PushVirtualDevice(function_node->virtual_device());
137 EnterFunctionBody();
138
139 DeviceAwareVisitExpr_(function_node);
140
141 // Leaving scope of function body.
142 ExitFunctionBody();
143 PopVirtualDevice();
144 // Function parameters go out of scope.
145 for (size_t i = 0; i < function_node->params.size(); ++i) {
146 PopBoundVar(function_node->params[i]);
147 }
148 }
149}
150
151void DeviceAwareExprVisitor::VisitExpr_(const LetNode* let_node) {
152 PreVisitLetBlock_(let_node);
153 std::vector<const LetNode*> bindings;
154 Expr expr = GetRef<Expr>(let_node);
155 while (const auto* inner_let_node = expr.as<LetNode>()) {
156 // Let-bound var (in pre visited version) goes into scope.
157 // (We'll just assume this is a letrec).
158 PushBoundVar(inner_let_node->var, GetVirtualDevice(inner_let_node->value));
159 PreVisitLetBinding_(inner_let_node->var, inner_let_node->value);
160 bindings.emplace_back(inner_let_node);
161 expr = inner_let_node->body;
162 }
163
164 VisitExpr(expr);
165
166 for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) {
167 // Let-bound var goes out of scope.
168 PopBoundVar((*itr)->var);
169 PostVisitLet_(*itr);
170 }
171 PostVisitLetBlock_(let_node);
172}
173
174void DeviceAwareExprVisitor::VisitExpr_(const CallNode* call_node) {
175 OnDeviceProps props = GetOnDeviceProps(call_node);
176 if (props.body.defined() && props.is_fixed()) {
177 // Entering lexical scope of fixed "on_device" call.
178 PushVirtualDevice(props.virtual_device);
179 VisitExpr(props.body);
180 // Leaving lexical scope of "on_device" call.
181 PopVirtualDevice();
182 } else {
183 DeviceAwareVisitExpr_(call_node);
184 }
185}
186
187void DeviceAwareExprVisitor::DeviceAwareVisitExpr_(const FunctionNode* function_node) {
188 ExprVisitor::VisitExpr_(function_node);
189}
190
191void DeviceAwareExprVisitor::DeviceAwareVisitExpr_(const CallNode* call_node) {
192 ExprVisitor::VisitExpr_(call_node);
193}
194
195void DeviceAwareExprVisitor::PreVisitLetBlock_(const LetNode* let_node) {
196 // no-op
197}
198
199void DeviceAwareExprVisitor::PreVisitLetBinding_(const Var& var, const Expr& value) {
200 VisitExpr(var);
201 VisitExpr(value);
202}
203
204void DeviceAwareExprVisitor::PostVisitLet_(const LetNode* let_node) {
205 // no-op
206}
207
208void DeviceAwareExprVisitor::PostVisitLetBlock_(const LetNode* let_node) {
209 // no-op
210}
211
212Expr DeviceAwareExprMutator::VisitExpr_(const FunctionNode* function_node) {
213 if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
214 // No tracking inside primitive functions.
215 return DeviceAwareVisitExpr_(function_node);
216 } else {
217 // Function parameters come into scope.
218 for (auto param : function_node->params) {
219 PushBoundVar(param, param->virtual_device());
220 }
221 // Entering scope of function body.
222 PushVirtualDevice(function_node->virtual_device());
223 EnterFunctionBody();
224
225 Expr result = DeviceAwareVisitExpr_(function_node);
226
227 // Leaving scope of function body.
228 ExitFunctionBody();
229 PopVirtualDevice();
230 // Function parameters go out of scope.
231 for (size_t i = 0; i < function_node->params.size(); ++i) {
232 PopBoundVar(function_node->params[i]);
233 }
234
235 return result;
236 }
237}
238
239Expr DeviceAwareExprMutator::VisitExpr_(const LetNode* let_node) {
240 PreVisitLetBlock_(let_node);
241 std::vector<std::tuple<Var, Expr, Span, const LetNode*>> bindings;
242 Expr expr = GetRef<Expr>(let_node);
243 while (const auto* inner_let_node = expr.as<LetNode>()) {
244 // Let-bound var (in pre visited version) goes into scope.
245 // (We'll just assume this is a letrec.)
246 PushBoundVar(inner_let_node->var, GetVirtualDevice(inner_let_node->value));
247 std::pair<Var, Expr> pair = PreVisitLetBinding_(inner_let_node->var, inner_let_node->value);
248 bindings.emplace_back(pair.first, pair.second, inner_let_node->span, inner_let_node);
249 expr = inner_let_node->body;
250 }
251
252 expr = VisitExpr(expr);
253
254 for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) {
255 // Let-bound var goes out of scope.
256 const LetNode* pre_let_node = std::get<3>(*itr);
257 PopBoundVar(pre_let_node->var);
258 Let post_let = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr),
259 /*body=*/expr, /*span=*/std::get<2>(*itr));
260 expr = PostVisitLet_(pre_let_node, post_let.get());
261 }
262 return PostVisitLetBlock_(let_node, expr.as<LetNode>());
263}
264
265Expr DeviceAwareExprMutator::VisitExpr_(const CallNode* call_node) {
266 OnDeviceProps props = GetOnDeviceProps(call_node);
267 if (props.body.defined() && props.is_fixed()) {
268 // Entering lexical scope of fixed "on_device" call.
269 PushVirtualDevice(props.virtual_device);
270 Expr expr = VisitExpr(props.body);
271 // Leaving lexical scope of "on_device" call.
272 PopVirtualDevice();
273 return MaybeOnDeviceWithProps(expr, props);
274 } else {
275 return DeviceAwareVisitExpr_(call_node);
276 }
277}
278
279Expr DeviceAwareExprMutator::DeviceAwareVisitExpr_(const FunctionNode* function_node) {
280 return ExprMutator::VisitExpr_(function_node);
281}
282
283Expr DeviceAwareExprMutator::DeviceAwareVisitExpr_(const CallNode* call_node) {
284 return ExprMutator::VisitExpr_(call_node);
285}
286
287void DeviceAwareExprMutator::PreVisitLetBlock_(const LetNode* let_node) { /* no-op */
288}
289
290std::pair<Var, Expr> DeviceAwareExprMutator::PreVisitLetBinding_(const Var& var,
291 const Expr& value) {
292 return std::make_pair(Downcast<Var>(VisitExpr(var)), VisitExpr(value));
293}
294
295Expr DeviceAwareExprMutator::PostVisitLet_(const LetNode* pre_let_node,
296 const LetNode* post_let_node) {
297 if (pre_let_node->var == post_let_node->var && pre_let_node->value == post_let_node->value &&
298 pre_let_node->body == post_let_node->body) {
299 return GetRef<Expr>(pre_let_node);
300 } else {
301 return GetRef<Expr>(post_let_node);
302 }
303}
304
305Expr DeviceAwareExprMutator::PostVisitLetBlock_(const LetNode* pre_let_node,
306 const LetNode* post_let_node) {
307 if (pre_let_node->var == post_let_node->var && pre_let_node->value == post_let_node->value &&
308 pre_let_node->body == post_let_node->body) {
309 return GetRef<Expr>(pre_let_node);
310 } else {
311 return GetRef<Expr>(post_let_node);
312 }
313}
314
315std::unordered_map<const ExprNode*, VirtualDevice> RecoverVirtualDeviceMap(const IRModule& mod,
316 const Expr& expr) {
317 class Visitor : public DeviceAwareExprVisitor {
318 public:
319 explicit Visitor(const Optional<IRModule>& maybe_mod) : DeviceAwareExprVisitor(maybe_mod) {}
320
321 void VisitExpr(const Expr& expr) final {
322 if (expr->IsInstance<OpNode>() || expr->IsInstance<ConstructorNode>()) {
323 // Don't record for ops or constructors since they are 'device polymorphic'.
324 } else {
325 map_[expr.get()] = GetVirtualDevice(expr);
326 }
327 DeviceAwareExprVisitor::VisitExpr(expr);
328 }
329
330 std::unordered_map<const ExprNode*, VirtualDevice> map_;
331 };
332
333 Visitor visitor(mod);
334 visitor.VisitExpr(expr);
335 return std::move(visitor.map_);
336}
337
338// Export the helper function for testing.
339TVM_REGISTER_GLOBAL("relay.transform.RecoverVirtualDeviceMap")
340 .set_body_typed([](const IRModule& mod, const Expr& expr) {
341 std::unordered_map<const ExprNode*, VirtualDevice> raw_map =
342 RecoverVirtualDeviceMap(mod, expr);
343 Map<Expr, VirtualDevice> map;
344 for (const auto& kv : raw_map) {
345 map.Set(GetRef<Expr>(kv.first), kv.second);
346 }
347 return map;
348 });
349
350} // namespace transform
351} // namespace relay
352} // namespace tvm
353