1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/transforms/device_aware_visitors.cc |
22 | * \brief Visitors which track the device for the current Relay expression. |
23 | */ |
24 | |
25 | #include "./device_aware_visitors.h" |
26 | |
27 | namespace tvm { |
28 | namespace relay { |
29 | namespace transform { |
30 | |
// TODO(mbs): This machinery can be used a) on expressions/modules which have not had
// device planning run, and b) on expressions for which we've not kept track of their
// containing module. For now we'll handle b) by being as forgiving as possible when
// recovering the device for an expression, and we'll support a) the same way. But better
// would be to ICHECK fail when, eg, a variable is not in scope or the lexical device stack
// is empty.
36 | |
37 | LexicalOnDeviceMixin::LexicalOnDeviceMixin(const Optional<IRModule>& maybe_mod) { |
38 | if (maybe_mod) { |
39 | for (const auto& kv : maybe_mod.value()->functions) { |
40 | if (const auto* function_node = kv.second.as<FunctionNode>()) { |
41 | VirtualDevice virtual_device = function_node->virtual_device(); |
42 | if (!virtual_device->IsFullyUnconstrained()) { |
43 | VLOG(2) << "global '" << kv.first->name_hint << "' has virtual device " << virtual_device; |
44 | global_var_virtual_devices_.emplace(kv.first, virtual_device); |
45 | } |
46 | } |
47 | } |
48 | } |
49 | } |
50 | |
51 | VirtualDevice LexicalOnDeviceMixin::GetVirtualDevice(const Expr& expr) const { |
52 | OnDeviceProps props = GetOnDeviceProps(expr); |
53 | if (props.body.defined() && props.is_fixed()) { |
54 | return props.virtual_device; |
55 | } else if (const auto* var_node = expr.as<VarNode>()) { |
56 | // Lookup variable binding. |
57 | auto itr = var_virtual_devices_.find(GetRef<Var>(var_node)); |
58 | if (itr != var_virtual_devices_.end()) { |
59 | return itr->second; |
60 | } |
61 | // else: fallthrough to unconstrained |
62 | } else if (const auto* global_var_node = expr.as<GlobalVarNode>()) { |
63 | // Lookup global variable. |
64 | auto itr = global_var_virtual_devices_.find(GetRef<GlobalVar>(global_var_node)); |
65 | if (itr != global_var_virtual_devices_.end()) { |
66 | return itr->second; |
67 | } |
68 | // else: fallthrough to unconstrained |
69 | } else if (const auto* function_node = expr.as<FunctionNode>()) { |
70 | if (function_node->HasNonzeroAttr(attr::kPrimitive)) { |
71 | if (!expr_virtual_devices_.empty()) { |
72 | // Use the currently in-scope device type. |
73 | return expr_virtual_devices_.back(); |
74 | } |
75 | // else: fallthrough to unconstrained |
76 | } else { |
77 | return function_node->virtual_device(); |
78 | } |
79 | } else { |
80 | if (!expr_virtual_devices_.empty()) { |
81 | // Use the currently in-scope device type. |
82 | return expr_virtual_devices_.back(); |
83 | } |
84 | // else: fallthrough to unconstrained |
85 | } |
86 | return VirtualDevice::FullyUnconstrained(); |
87 | } |
88 | |
89 | void LexicalOnDeviceMixin::EnterFunctionBody() { ++function_nesting_; } |
90 | |
91 | void LexicalOnDeviceMixin::ExitFunctionBody() { |
92 | ICHECK_GT(function_nesting_, 0); |
93 | --function_nesting_; |
94 | } |
95 | |
96 | void LexicalOnDeviceMixin::PushVirtualDevice(const VirtualDevice& virtual_device) { |
97 | expr_virtual_devices_.emplace_back(virtual_device); |
98 | } |
99 | |
100 | void LexicalOnDeviceMixin::PopVirtualDevice() { |
101 | if (expr_virtual_devices_.empty()) { |
102 | return; |
103 | } |
104 | expr_virtual_devices_.pop_back(); |
105 | } |
106 | |
107 | void LexicalOnDeviceMixin::PushBoundVar(Var var, const VirtualDevice& virtual_device) { |
108 | if (virtual_device->IsFullyUnconstrained()) { |
109 | return; |
110 | } |
111 | ICHECK(var_virtual_devices_.find(var) == var_virtual_devices_.end()); |
112 | var_virtual_devices_.emplace(std::move(var), virtual_device); |
113 | } |
114 | |
115 | void LexicalOnDeviceMixin::PopBoundVar(const Var& var) { |
116 | auto itr = var_virtual_devices_.find(var); |
117 | if (itr == var_virtual_devices_.end()) { |
118 | return; |
119 | } |
120 | var_virtual_devices_.erase(itr); |
121 | } |
122 | |
123 | // TODO(mbs): We'd probably have less tedious code duplication if we redefined the memoizing |
124 | // mutator on top of the generic Functor. |
125 | |
126 | void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) { |
127 | if (function_node->HasNonzeroAttr(attr::kPrimitive)) { |
128 | // No tracking inside primitive functions. |
129 | DeviceAwareVisitExpr_(function_node); |
130 | } else { |
131 | // Function parameters come into scope. |
132 | for (auto param : function_node->params) { |
133 | PushBoundVar(param, param->virtual_device()); |
134 | } |
135 | // Entering scope of function body. |
136 | PushVirtualDevice(function_node->virtual_device()); |
137 | EnterFunctionBody(); |
138 | |
139 | DeviceAwareVisitExpr_(function_node); |
140 | |
141 | // Leaving scope of function body. |
142 | ExitFunctionBody(); |
143 | PopVirtualDevice(); |
144 | // Function parameters go out of scope. |
145 | for (size_t i = 0; i < function_node->params.size(); ++i) { |
146 | PopBoundVar(function_node->params[i]); |
147 | } |
148 | } |
149 | } |
150 | |
151 | void DeviceAwareExprVisitor::VisitExpr_(const LetNode* let_node) { |
152 | PreVisitLetBlock_(let_node); |
153 | std::vector<const LetNode*> bindings; |
154 | Expr expr = GetRef<Expr>(let_node); |
155 | while (const auto* inner_let_node = expr.as<LetNode>()) { |
156 | // Let-bound var (in pre visited version) goes into scope. |
157 | // (We'll just assume this is a letrec). |
158 | PushBoundVar(inner_let_node->var, GetVirtualDevice(inner_let_node->value)); |
159 | PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); |
160 | bindings.emplace_back(inner_let_node); |
161 | expr = inner_let_node->body; |
162 | } |
163 | |
164 | VisitExpr(expr); |
165 | |
166 | for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { |
167 | // Let-bound var goes out of scope. |
168 | PopBoundVar((*itr)->var); |
169 | PostVisitLet_(*itr); |
170 | } |
171 | PostVisitLetBlock_(let_node); |
172 | } |
173 | |
174 | void DeviceAwareExprVisitor::VisitExpr_(const CallNode* call_node) { |
175 | OnDeviceProps props = GetOnDeviceProps(call_node); |
176 | if (props.body.defined() && props.is_fixed()) { |
177 | // Entering lexical scope of fixed "on_device" call. |
178 | PushVirtualDevice(props.virtual_device); |
179 | VisitExpr(props.body); |
180 | // Leaving lexical scope of "on_device" call. |
181 | PopVirtualDevice(); |
182 | } else { |
183 | DeviceAwareVisitExpr_(call_node); |
184 | } |
185 | } |
186 | |
187 | void DeviceAwareExprVisitor::DeviceAwareVisitExpr_(const FunctionNode* function_node) { |
188 | ExprVisitor::VisitExpr_(function_node); |
189 | } |
190 | |
191 | void DeviceAwareExprVisitor::DeviceAwareVisitExpr_(const CallNode* call_node) { |
192 | ExprVisitor::VisitExpr_(call_node); |
193 | } |
194 | |
195 | void DeviceAwareExprVisitor::PreVisitLetBlock_(const LetNode* let_node) { |
196 | // no-op |
197 | } |
198 | |
199 | void DeviceAwareExprVisitor::PreVisitLetBinding_(const Var& var, const Expr& value) { |
200 | VisitExpr(var); |
201 | VisitExpr(value); |
202 | } |
203 | |
204 | void DeviceAwareExprVisitor::PostVisitLet_(const LetNode* let_node) { |
205 | // no-op |
206 | } |
207 | |
208 | void DeviceAwareExprVisitor::PostVisitLetBlock_(const LetNode* let_node) { |
209 | // no-op |
210 | } |
211 | |
212 | Expr DeviceAwareExprMutator::VisitExpr_(const FunctionNode* function_node) { |
213 | if (function_node->HasNonzeroAttr(attr::kPrimitive)) { |
214 | // No tracking inside primitive functions. |
215 | return DeviceAwareVisitExpr_(function_node); |
216 | } else { |
217 | // Function parameters come into scope. |
218 | for (auto param : function_node->params) { |
219 | PushBoundVar(param, param->virtual_device()); |
220 | } |
221 | // Entering scope of function body. |
222 | PushVirtualDevice(function_node->virtual_device()); |
223 | EnterFunctionBody(); |
224 | |
225 | Expr result = DeviceAwareVisitExpr_(function_node); |
226 | |
227 | // Leaving scope of function body. |
228 | ExitFunctionBody(); |
229 | PopVirtualDevice(); |
230 | // Function parameters go out of scope. |
231 | for (size_t i = 0; i < function_node->params.size(); ++i) { |
232 | PopBoundVar(function_node->params[i]); |
233 | } |
234 | |
235 | return result; |
236 | } |
237 | } |
238 | |
239 | Expr DeviceAwareExprMutator::VisitExpr_(const LetNode* let_node) { |
240 | PreVisitLetBlock_(let_node); |
241 | std::vector<std::tuple<Var, Expr, Span, const LetNode*>> bindings; |
242 | Expr expr = GetRef<Expr>(let_node); |
243 | while (const auto* inner_let_node = expr.as<LetNode>()) { |
244 | // Let-bound var (in pre visited version) goes into scope. |
245 | // (We'll just assume this is a letrec.) |
246 | PushBoundVar(inner_let_node->var, GetVirtualDevice(inner_let_node->value)); |
247 | std::pair<Var, Expr> pair = PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); |
248 | bindings.emplace_back(pair.first, pair.second, inner_let_node->span, inner_let_node); |
249 | expr = inner_let_node->body; |
250 | } |
251 | |
252 | expr = VisitExpr(expr); |
253 | |
254 | for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { |
255 | // Let-bound var goes out of scope. |
256 | const LetNode* pre_let_node = std::get<3>(*itr); |
257 | PopBoundVar(pre_let_node->var); |
258 | Let post_let = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), |
259 | /*body=*/expr, /*span=*/std::get<2>(*itr)); |
260 | expr = PostVisitLet_(pre_let_node, post_let.get()); |
261 | } |
262 | return PostVisitLetBlock_(let_node, expr.as<LetNode>()); |
263 | } |
264 | |
265 | Expr DeviceAwareExprMutator::VisitExpr_(const CallNode* call_node) { |
266 | OnDeviceProps props = GetOnDeviceProps(call_node); |
267 | if (props.body.defined() && props.is_fixed()) { |
268 | // Entering lexical scope of fixed "on_device" call. |
269 | PushVirtualDevice(props.virtual_device); |
270 | Expr expr = VisitExpr(props.body); |
271 | // Leaving lexical scope of "on_device" call. |
272 | PopVirtualDevice(); |
273 | return MaybeOnDeviceWithProps(expr, props); |
274 | } else { |
275 | return DeviceAwareVisitExpr_(call_node); |
276 | } |
277 | } |
278 | |
279 | Expr DeviceAwareExprMutator::DeviceAwareVisitExpr_(const FunctionNode* function_node) { |
280 | return ExprMutator::VisitExpr_(function_node); |
281 | } |
282 | |
283 | Expr DeviceAwareExprMutator::DeviceAwareVisitExpr_(const CallNode* call_node) { |
284 | return ExprMutator::VisitExpr_(call_node); |
285 | } |
286 | |
287 | void DeviceAwareExprMutator::PreVisitLetBlock_(const LetNode* let_node) { /* no-op */ |
288 | } |
289 | |
290 | std::pair<Var, Expr> DeviceAwareExprMutator::PreVisitLetBinding_(const Var& var, |
291 | const Expr& value) { |
292 | return std::make_pair(Downcast<Var>(VisitExpr(var)), VisitExpr(value)); |
293 | } |
294 | |
295 | Expr DeviceAwareExprMutator::PostVisitLet_(const LetNode* pre_let_node, |
296 | const LetNode* post_let_node) { |
297 | if (pre_let_node->var == post_let_node->var && pre_let_node->value == post_let_node->value && |
298 | pre_let_node->body == post_let_node->body) { |
299 | return GetRef<Expr>(pre_let_node); |
300 | } else { |
301 | return GetRef<Expr>(post_let_node); |
302 | } |
303 | } |
304 | |
305 | Expr DeviceAwareExprMutator::PostVisitLetBlock_(const LetNode* pre_let_node, |
306 | const LetNode* post_let_node) { |
307 | if (pre_let_node->var == post_let_node->var && pre_let_node->value == post_let_node->value && |
308 | pre_let_node->body == post_let_node->body) { |
309 | return GetRef<Expr>(pre_let_node); |
310 | } else { |
311 | return GetRef<Expr>(post_let_node); |
312 | } |
313 | } |
314 | |
315 | std::unordered_map<const ExprNode*, VirtualDevice> RecoverVirtualDeviceMap(const IRModule& mod, |
316 | const Expr& expr) { |
317 | class Visitor : public DeviceAwareExprVisitor { |
318 | public: |
319 | explicit Visitor(const Optional<IRModule>& maybe_mod) : DeviceAwareExprVisitor(maybe_mod) {} |
320 | |
321 | void VisitExpr(const Expr& expr) final { |
322 | if (expr->IsInstance<OpNode>() || expr->IsInstance<ConstructorNode>()) { |
323 | // Don't record for ops or constructors since they are 'device polymorphic'. |
324 | } else { |
325 | map_[expr.get()] = GetVirtualDevice(expr); |
326 | } |
327 | DeviceAwareExprVisitor::VisitExpr(expr); |
328 | } |
329 | |
330 | std::unordered_map<const ExprNode*, VirtualDevice> map_; |
331 | }; |
332 | |
333 | Visitor visitor(mod); |
334 | visitor.VisitExpr(expr); |
335 | return std::move(visitor.map_); |
336 | } |
337 | |
338 | // Export the helper function for testing. |
339 | TVM_REGISTER_GLOBAL("relay.transform.RecoverVirtualDeviceMap" ) |
340 | .set_body_typed([](const IRModule& mod, const Expr& expr) { |
341 | std::unordered_map<const ExprNode*, VirtualDevice> raw_map = |
342 | RecoverVirtualDeviceMap(mod, expr); |
343 | Map<Expr, VirtualDevice> map; |
344 | for (const auto& kv : raw_map) { |
345 | map.Set(GetRef<Expr>(kv.first), kv.second); |
346 | } |
347 | return map; |
348 | }); |
349 | |
350 | } // namespace transform |
351 | } // namespace relay |
352 | } // namespace tvm |
353 | |