1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/transforms/dead_code.cc |
22 | * \brief Elides or inlines let-bindings. |
23 | * |
24 | * TODO(mbs): Track dead writes into references. |
25 | */ |
26 | |
27 | #include <tvm/relay/analysis.h> |
28 | #include <tvm/relay/expr_functor.h> |
29 | #include <tvm/relay/pattern_functor.h> |
30 | #include <tvm/relay/transform.h> |
31 | |
32 | #include "../op/call/call.h" |
33 | |
34 | namespace tvm { |
35 | namespace relay { |
36 | namespace { |
37 | |
/*!
 * \brief Maximum depth of nested calls to analyze. PurityVisitor does not descend into calls
 * once this depth is reached and instead conservatively assumes such calls are impure.
 */
constexpr int kMaxCallDepth = 25;
40 | |
/*!
 * \brief Captures (an approximation of) the purity for a Relay sub-expression. A pure
 * sub-expression is guaranteed never to access or mutate state. Thus the sub-expression
 * can safely be elided (if its result is never used), or inlined (which may change the
 * number of times and program order for the evaluation.)
 *
 * Both fields default to false (impure), which is the conservative answer. A
 * default-constructed Purity therefore never claims purity by accident.
 */
struct Purity {
  /*!
   * \brief True if evaluating the sub-expression itself is pure.
   */
  bool pure_eval = false;
  /*!
   * \brief If the sub-expression is first-order then always true. Otherwise true only if
   * evaluating a call to the sub-expression is pure. See [RULE A] below.
   */
  bool pure_call = false;
};
58 | |
59 | /*! |
60 | * \brief Visits all the global functions in a module and records the purity of every let-bound |
61 | * value. |
62 | * |
63 | * (See also inline.cc for function inlining.) |
64 | * |
65 | * Generally we track whether evaluation of a sub-expression is definitely pure. However for |
66 | * sub-expressions f of higher-order type we also track the 'call purity' of evaling a call to f: |
67 | * - [RULE A] If f's result is itself higher-order then f is call-pure only if the result of f is |
68 | * also call-pure. |
69 | * - [RULE B] Higher-order function arguments are assumed call impure. |
70 | * - [RULE C] We assume functions extracted from tuples are call impure. |
71 | * - [RULE D] We assume functions extracted from references are call impure. |
72 | * - [RULE E] We assume functions extracted from ADTs are call impure. |
73 | * - [RULE F] We assume all external Functions and PrimFuncs are call impure. |
74 | */ |
75 | class PurityVisitor : ExprFunctor<Purity(const Expr&)> { |
76 | public: |
77 | explicit PurityVisitor(IRModule mod) : mod_(std::move(mod)), current_call_depth_(0) {} |
78 | |
79 | /*! \brief Visit all the functions in the module. */ |
80 | void VisitModule() { |
81 | VLOG_CONTEXT << "PurityVisitor" ; |
82 | // It is safe to visit the global functions in any order. Recursive global functions are |
83 | // allowed. |
84 | for (const auto& kv : mod_->functions) { |
85 | if (const auto* function_node = kv.second.as<FunctionNode>()) { |
86 | if (function_node->HasNonzeroAttr(attr::kPrimitive) || |
87 | function_node->HasNonzeroAttr(attr::kExtern)) { |
88 | // Ignore primitive and external functions. |
89 | continue; |
90 | } |
91 | // Everything of interest will be recorded in the purity maps so we ignore the result. |
92 | (void)VisitGlobalFunction(kv.first, GetRef<Function>(function_node)); |
93 | } |
94 | } |
95 | } |
96 | |
97 | /*! |
98 | * \brief Returns a map from every let-bound variable to whether its let-bound value is |
99 | * definitely pure. |
100 | */ |
101 | std::unordered_map<const VarNode*, bool> GetPurityMap() const { |
102 | std::unordered_map<const VarNode*, bool> result; |
103 | for (const auto& kv : var_to_purity_) { |
104 | result.emplace(kv.first, kv.second.pure_eval); |
105 | } |
106 | return result; |
107 | } |
108 | |
109 | private: |
110 | Purity VisitExpr(const Expr& expr) final { |
111 | auto it = memo_.find(expr.get()); |
112 | if (it != this->memo_.end()) { |
113 | return it->second; |
114 | } else { |
115 | Purity result = ExprFunctor::VisitExpr(expr); |
116 | memo_[expr.get()] = result; |
117 | return result; |
118 | } |
119 | } |
120 | |
121 | Purity VisitExpr_(const ConstantNode*) final { return {/*pure_eval=*/true, /*pure_call=*/true}; } |
122 | |
123 | Purity VisitExpr_(const ConstructorNode*) final { |
124 | return {/*pure_eval=*/true, /*pure_call=*/true}; |
125 | } |
126 | |
127 | Purity VisitExpr_(const OpNode* op_node) final { |
128 | // Primitive operators are pure unless marked as 'stateful'. |
129 | static OpAttrMap<bool> attr_map = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful" ); |
130 | bool is_stateful = attr_map.count(GetRef<Op>(op_node)) && attr_map[GetRef<Op>(op_node)]; |
131 | return {/*pure_eval=*/true, /*pure_call=*/!is_stateful}; |
132 | } |
133 | |
134 | Purity VisitExpr_(const GlobalVarNode* global_var_node) final { |
135 | auto global_var = GetRef<GlobalVar>(global_var_node); |
136 | ICHECK(mod_->ContainGlobalVar(global_var_node->name_hint)) |
137 | << "No definition for '" << global_var_node->name_hint << "'" ; |
138 | auto func = mod_->Lookup(global_var); |
139 | if (const auto* function_node = func.as<FunctionNode>()) { |
140 | if (!function_node->HasNonzeroAttr(attr::kExtern)) { |
141 | return VisitGlobalFunction(global_var, GetRef<Function>(function_node)); |
142 | } |
143 | } |
144 | // Assume externals and PrimFuncs are call-impure [RULE F]. |
145 | // (If they are pure then we should have dealt with them before lowering.) |
146 | return {/*pure_eval==*/true, /*pure_call=*/false}; |
147 | } |
148 | |
149 | Purity VisitExpr_(const VarNode* var_node) final { |
150 | // The var is bound to a value, but if that value is a function we need to propagate the |
151 | // function body's purity. |
152 | ICHECK(var_to_purity_.count(var_node)) << PrettyPrint(GetRef<Var>(var_node)); |
153 | return {/*pure_eval=*/true, /*pure_call=*/var_to_purity_[var_node].pure_call}; |
154 | } |
155 | |
156 | Purity VisitExpr_(const FunctionNode* function_node) final { |
157 | for (const auto& param : function_node->params) { |
158 | // Any higher-order parameters are assumed to be call-impure [RULE B] |
159 | var_to_purity_[param.get()] = {/*pure_eval=*/true, /*pure_call=*/IsFirstOrder(param)}; |
160 | } |
161 | Purity body_purity = VisitExpr(function_node->body); |
162 | // The function itself is a value and thus pure. If the function returns |
163 | // a function we'll fold its purity in here [RULE A] |
164 | return {/*pure_eval=*/true, /*pure_call=*/body_purity.pure_eval && body_purity.pure_call}; |
165 | } |
166 | |
167 | Purity VisitExpr_(const LetNode* let_node) final { |
168 | Expr expr = GetRef<Expr>(let_node); |
169 | bool all_values_pure_eval = true; |
170 | while (const auto* inner_let_node = expr.as<LetNode>()) { |
171 | // In case the value is a recursive function assume the let-bound variable is call-pure. |
172 | var_to_purity_[inner_let_node->var.get()] = {/*pure_eval=*/true, /*pure_call=*/true}; |
173 | Purity value_purity = VisitExpr(inner_let_node->value); |
174 | // Now revise the variable to it's true purity. |
175 | var_to_purity_[inner_let_node->var.get()] = value_purity; |
176 | VLOG(2) << (value_purity.pure_eval ? "pure" : "impure" ) << " expression:" << std::endl |
177 | << PrettyPrint(inner_let_node->value) << std::endl |
178 | << "let-bound to variable:" << std::endl |
179 | << PrettyPrint(inner_let_node->var); |
180 | all_values_pure_eval = all_values_pure_eval && value_purity.pure_eval; |
181 | expr = inner_let_node->body; |
182 | } |
183 | Purity body_purity = VisitExpr(expr); |
184 | return {/*pure_eval=*/all_values_pure_eval && body_purity.pure_eval, |
185 | /*pure_call=*/body_purity.pure_call}; |
186 | } |
187 | |
188 | Purity VisitExpr_(const CallNode* call_node) final { |
189 | auto call = GetRef<Call>(call_node); |
190 | if (current_call_depth_ >= kMaxCallDepth) { |
191 | // Assume impure. |
192 | VLOG(2) << "assuming call is impure since too deeply nested" ; |
193 | return {/*pure_eval=*/false, /*pure_call*/ IsFirstOrder(call)}; |
194 | } |
195 | |
196 | ++current_call_depth_; |
197 | |
198 | // We can work with calls in both pre- and post-lowered form. |
199 | Call vanilla_call = GetAnyCall(call_node); |
200 | |
201 | // Find purity for the callee and the args. |
202 | Purity callee_purity = VisitExpr(vanilla_call->op); |
203 | bool all_args_pure_eval = true; |
204 | for (const auto& arg : vanilla_call->args) { |
205 | Purity arg_purity = VisitExpr(arg); |
206 | all_args_pure_eval = all_args_pure_eval && arg_purity.pure_eval; |
207 | } |
208 | |
209 | VLOG(2) << (callee_purity.pure_call ? "pure" : "impure" ) << " call to:" << std::endl |
210 | << PrettyPrint(vanilla_call->op); |
211 | |
212 | ICHECK_GT(current_call_depth_, 0); |
213 | --current_call_depth_; |
214 | |
215 | // If the callee's result is itself a function then by [RULE A] its purity |
216 | // is given by callee_purity.pure_call. |
217 | return {/*pure_eval=*/all_args_pure_eval && callee_purity.pure_eval && callee_purity.pure_call, |
218 | /*pure_call=*/IsFirstOrder(call) || callee_purity.pure_call}; |
219 | } |
220 | |
221 | Purity VisitExpr_(const IfNode* if_node) final { |
222 | Purity cond_purity = VisitExpr(if_node->cond); |
223 | ICHECK(cond_purity.pure_call); // conditional is first-order |
224 | Purity true_purity = VisitExpr(if_node->true_branch); |
225 | Purity false_purity = VisitExpr(if_node->false_branch); |
226 | return {/*pure_eval=*/cond_purity.pure_eval && true_purity.pure_eval && false_purity.pure_eval, |
227 | /*pure_call=*/true_purity.pure_call && false_purity.pure_call}; |
228 | } |
229 | |
230 | Purity VisitExpr_(const TupleNode* tuple_node) final { |
231 | bool all_fields_pure = true; |
232 | for (const auto& field : tuple_node->fields) { |
233 | // The call purity of each tuple field is lost [RULE C]. |
234 | Purity field_purity = VisitExpr(field); |
235 | if (!field_purity.pure_eval) { |
236 | all_fields_pure = false; |
237 | } |
238 | } |
239 | return {/*pure_eval=*/all_fields_pure, /*pure_call=*/true}; |
240 | } |
241 | |
242 | Purity VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { |
243 | Purity tuple_purity = VisitExpr(tuple_get_item_node->tuple); |
244 | ICHECK(tuple_purity.pure_call); // tuple is first-order |
245 | // We don't track call purity through tuple fields, so if the result is a function type we |
246 | // must assume it is call impure [RULE C]. |
247 | return {/*pure_eval=*/tuple_purity.pure_eval, |
248 | /*pure_call=*/IsFirstOrder(GetRef<TupleGetItem>(tuple_get_item_node))}; |
249 | } |
250 | |
251 | Purity VisitExpr_(const RefCreateNode*) final { |
252 | // The creation of the ref itself is unobservable other than via the reads/writes into it. |
253 | return {/*pure_eval=*/true, /*pure_call=*/true}; |
254 | } |
255 | |
256 | Purity VisitExpr_(const RefWriteNode* ref_write_node) final { |
257 | Purity ref_purity = VisitExpr(ref_write_node->ref); |
258 | ICHECK(ref_purity.pure_call); // reference is first-order |
259 | // The call purity of the written value is lost [RULE D]. |
260 | // (But we must still visit to accumulate purity for any let-bindings within in.) |
261 | (void)VisitExpr(ref_write_node->value); |
262 | return {/*pure_eval=*/false, /*pure_call=*/true}; |
263 | } |
264 | |
265 | Purity VisitExpr_(const RefReadNode* ref_read_node) final { |
266 | Purity ref_purity = VisitExpr(ref_read_node->ref); |
267 | ICHECK(ref_purity.pure_call); // reference is first-order |
268 | // We don't track call purity through reference values, so if the result is a function |
269 | // type we must assume it is call impure [RULE D]. |
270 | return {/*pure_eval=*/false, /*pure_call=*/IsFirstOrder(GetRef<RefRead>(ref_read_node))}; |
271 | } |
272 | |
273 | class PurityPatternVisitor : public PatternVisitor { |
274 | public: |
275 | explicit PurityPatternVisitor(PurityVisitor* outer) : outer_(outer) {} |
276 | |
277 | private: |
278 | void VisitPattern_(const PatternVarNode* pattern_var_node) final { |
279 | // We don't track call purity through ADTs, so if var is a function type we must assume |
280 | // it is call impure [RULE E]. |
281 | outer_->var_to_purity_[pattern_var_node->var.get()] = { |
282 | /*pure_eval=*/true, /*pure_call=*/IsFirstOrder(pattern_var_node->var)}; |
283 | } |
284 | |
285 | /*! \brief (Mutable borrow of) the outer visitor. */ |
286 | PurityVisitor* outer_; |
287 | }; |
288 | |
289 | Purity VisitExpr_(const MatchNode* match_node) final { |
290 | Purity data_purity = VisitExpr(match_node->data); |
291 | ICHECK(data_purity.pure_call); // ADT is first order |
292 | bool all_clauses_pure_eval = true; |
293 | bool all_clauses_pure_call = true; |
294 | for (const auto& clause : match_node->clauses) { |
295 | PurityPatternVisitor pattern_visitor(this); |
296 | pattern_visitor.VisitPattern(clause->lhs); |
297 | Purity rhs_purity = VisitExpr(clause->rhs); |
298 | all_clauses_pure_eval = all_clauses_pure_eval && rhs_purity.pure_eval; |
299 | all_clauses_pure_call = all_clauses_pure_call && rhs_purity.pure_call; |
300 | } |
301 | return {/*pure_eval=*/data_purity.pure_eval && all_clauses_pure_eval, |
302 | /*pure_call=*/all_clauses_pure_call}; |
303 | } |
304 | |
305 | /*! \brief Visits \p func bound to global \p var and returns it's purity. */ |
306 | Purity VisitGlobalFunction(const GlobalVar& var, const Function& func) { |
307 | VLOG_CONTEXT << "func " << var->name_hint; |
308 | VLOG(2) << "visiting" ; |
309 | auto itr = global_var_to_purity_.find(var.get()); |
310 | if (itr != global_var_to_purity_.end()) { |
311 | // We've already visited the function body. |
312 | return itr->second; |
313 | } |
314 | // We are entering the body of a possibly-recursive global function. Assume it's body is pure. |
315 | global_var_to_purity_[var.get()] = {/*pure_eval=*/true, /*pure_call=*/true}; |
316 | // Visit the global function for the first time. |
317 | Purity func_purity = VisitExpr(func); |
318 | // Update with the true purity. |
319 | global_var_to_purity_[var.get()] = func_purity; |
320 | return func_purity; |
321 | } |
322 | |
323 | static bool IsFirstOrder(const Expr& expr) { |
324 | return expr->checked_type().as<FuncTypeNode>() == nullptr; |
325 | } |
326 | |
327 | /*! \brief The module we're analyzing. */ |
328 | IRModule mod_; |
329 | |
330 | /*! |
331 | * \brief Maps each let-bound and global variable to the purity of the value it is bound to. |
332 | * If the variable is bound to a function then the purity of saturating that function is also |
333 | * tracked. |
334 | * |
335 | * Note that global_var_to_purity_, and all the 'pure_call' fields, are only needed internally |
336 | * during the analysis, andonly the var_to_purity_ 'pure_eval' fields are used downstream. |
337 | */ |
338 | std::unordered_map<const VarNode*, Purity> var_to_purity_; |
339 | std::unordered_map<const GlobalVarNode*, Purity> global_var_to_purity_; |
340 | |
341 | /*! \brief The current call depth. We'll just assume deeply nested calls are impure rather than |
342 | * spending all that time to check for sure. A deeply nested call is almost certain to be needed |
343 | * anyway. |
344 | */ |
345 | |
346 | int current_call_depth_; |
347 | |
348 | /*! \brief Internal map used for memoization. */ |
349 | std::unordered_map<const ExprNode*, Purity> memo_; |
350 | }; |
351 | |
352 | /*! |
353 | * \brief Accumulate the bound values and usage count for each let-bound variable. |
354 | * |
355 | * We don't attempt to track the number of calls to local functions, and instead just assume they |
356 | * are called at least twice. |
357 | */ |
358 | class UsageVisitor : public ExprVisitor { |
359 | public: |
360 | /*! \brief Accumulates the expression bound to every let-bound variable. */ |
361 | std::unordered_map<const VarNode*, Expr> let_bound_values_; |
362 | /*! \brief Accumulates the usage count for every let-bound variable. */ |
363 | std::unordered_map<const VarNode*, size_t> use_map_; |
364 | |
365 | explicit UsageVisitor(const std::unordered_map<const VarNode*, bool>* var_to_purity, |
366 | bool default_purity) |
367 | : var_to_purity_(var_to_purity), default_purity_(default_purity) {} |
368 | |
369 | void VisitExpr(const Expr& expr) final { |
370 | // Once we've seen 2 usages of a variable we know it can be neither elided nor inlined, |
371 | // so can stop visiting again. |
372 | if (++visit_counter_[expr.get()] <= 2) { |
373 | ExprFunctor<void(const Expr&)>::VisitExpr(expr); |
374 | } |
375 | } |
376 | |
377 | void VisitExpr_(const FunctionNode* function_node) final { |
378 | ++current_scope_level_; |
379 | ExprVisitor::VisitExpr_(function_node); |
380 | ICHECK_GT(current_scope_level_, 0); |
381 | --current_scope_level_; |
382 | } |
383 | |
384 | void VisitExpr_(const LetNode* let_node) final { |
385 | Expr expr = GetRef<Expr>(let_node); |
386 | while (const auto* inner_let_node = expr.as<LetNode>()) { |
387 | ++visit_counter_[inner_let_node]; |
388 | let_bound_values_[inner_let_node->var.get()] = inner_let_node->value; |
389 | VLOG(2) << "seen let-binding for:" << std::endl << PrettyPrint(inner_let_node->var); |
390 | use_map_[inner_let_node->var.get()] = 0; |
391 | scope_level_map_[inner_let_node->var.get()] = current_scope_level_; |
392 | if (is_pure(inner_let_node->var.get())) { |
393 | // We'll defer visiting the let-bound value until we've seen the first use of the let-bound |
394 | // variable and thus know it must be evaluated. |
395 | // no-op. |
396 | } else { |
397 | // The let-bound value is impure so must always be evaluated. Visit now. |
398 | VisitExpr(inner_let_node->value); |
399 | } |
400 | expr = inner_let_node->body; |
401 | } |
402 | VisitExpr(expr); |
403 | } |
404 | |
405 | void VisitExpr_(const VarNode* var_node) final { |
406 | if (let_bound_values_.count(var_node)) { |
407 | size_t& n = use_map_[var_node]; |
408 | ++n; |
409 | VLOG(2) << var_node->name_hint() << " = " << n; |
410 | if (n == 1 && is_pure(var_node)) { |
411 | // Now that we have at least one use of the let-bound var, we know the let-bound |
412 | // value is necessary. |
413 | VisitExpr(let_bound_values_[var_node]); |
414 | } |
415 | if (scope_level_map_[var_node] < current_scope_level_) { |
416 | // Since the variable was bound outside of the current local function, assume the |
417 | // function will be called at least twice. |
418 | ++n; |
419 | VLOG(2) << var_node->name_hint() << " = " << n << " (bound at level " |
420 | << scope_level_map_[var_node] << " but used at level " << current_scope_level_ |
421 | << ")" ; |
422 | } |
423 | } |
424 | // else: nothing to be done for function parameters or variable in match patterns. |
425 | } |
426 | |
427 | bool is_pure(const VarNode* var_node) const { |
428 | auto itr = var_to_purity_->find(var_node); |
429 | return itr == var_to_purity_->end() ? default_purity_ : itr->second; |
430 | } |
431 | |
432 | /*! \brief (Immutable borrow of) the already determined purity for every let-bound variable. */ |
433 | const std::unordered_map<const VarNode*, bool>* var_to_purity_; |
434 | /*! \brief The default purity for variables which are not in the above map. */ |
435 | bool default_purity_; |
436 | /*! |
437 | * \brief The current scope level. 0 for global functions. Incremented by one within each |
438 | * let-bound local function. Necessary so we can avoid inlining an expensive let-bound computation |
439 | * into a function which could be called more than once. |
440 | */ |
441 | int current_scope_level_ = 0; |
442 | /*! \brief Accumulates the scope level for every let-bound variable. */ |
443 | std::unordered_map<const VarNode*, int> scope_level_map_; |
444 | }; |
445 | |
446 | /*! \brief Eliminate/inline let-bound values when sound to do so. */ |
447 | class EliminatorMutator : public ExprMutator { |
448 | public: |
449 | EliminatorMutator(bool inline_once, |
450 | const std::unordered_map<const VarNode*, Expr>* let_bound_values, |
451 | const std::unordered_map<const VarNode*, size_t>* use_map, |
452 | const std::unordered_map<const VarNode*, bool>* var_to_purity, |
453 | bool default_purity) |
454 | : inline_once_(inline_once), |
455 | let_bound_values_(let_bound_values), |
456 | use_map_(use_map), |
457 | var_to_purity_(var_to_purity), |
458 | default_purity_(default_purity) {} |
459 | |
460 | private: |
461 | enum Action { kElide, kInline, kNoChange }; |
462 | |
463 | /*! \brief What should we do with let-binding for \p var_node? */ |
464 | Action ActionFor(const VarNode* var_node) { |
465 | if (let_bound_values_->count(var_node) == 0) { |
466 | // Not let-bound var. |
467 | return kNoChange; |
468 | } |
469 | if (!is_pure(var_node)) { |
470 | // The let-bound value is impure -- we must leave it exactly where it is. |
471 | return kNoChange; |
472 | } |
473 | switch (use_map_->count(var_node) ? use_map_->at(var_node) : 0) { |
474 | case 0: |
475 | return kElide; |
476 | case 1: |
477 | return inline_once_ ? kInline : kNoChange; |
478 | default: |
479 | return kNoChange; |
480 | } |
481 | } |
482 | |
483 | Expr VisitExpr_(const VarNode* var_node) final { |
484 | if (ActionFor(var_node) == kInline) { |
485 | VLOG(1) << "inlining let-bound variable:" << std::endl << PrettyPrint(GetRef<Var>(var_node)); |
486 | return VisitExpr(let_bound_values_->at(var_node)); |
487 | } else { |
488 | return GetRef<Var>(var_node); |
489 | } |
490 | } |
491 | |
492 | Expr VisitExpr_(const LetNode* op) final { |
493 | auto pre_visit = [this](const LetNode* op) { |
494 | if (ActionFor(op->var.get()) != kElide) { |
495 | (void)VisitExpr(op->value); |
496 | } |
497 | }; |
498 | auto post_visit = [this](const LetNode* op) { |
499 | Expr body = VisitExpr(op->body); |
500 | auto expr = GetRef<Expr>(op); |
501 | switch (ActionFor(op->var.get())) { |
502 | case kElide: |
503 | VLOG(1) << "eliding let-bound variable:" << std::endl << PrettyPrint(op->var); |
504 | memo_[expr] = body; |
505 | break; |
506 | case kInline: |
507 | // Already inlined at use-side. |
508 | memo_[expr] = body; |
509 | break; |
510 | case kNoChange: |
511 | Expr value = VisitExpr(op->value); |
512 | memo_[expr] = Let(op->var, value, body); |
513 | break; |
514 | } |
515 | }; |
516 | ExpandANormalForm(op, pre_visit, post_visit); |
517 | return memo_[GetRef<Expr>(op)]; |
518 | } |
519 | |
520 | bool is_pure(const VarNode* var_node) const { |
521 | auto itr = var_to_purity_->find(var_node); |
522 | return itr == var_to_purity_->end() ? default_purity_ : itr->second; |
523 | } |
524 | |
525 | bool inline_once_; |
526 | const std::unordered_map<const VarNode*, Expr>* let_bound_values_; |
527 | const std::unordered_map<const VarNode*, size_t>* use_map_; |
528 | const std::unordered_map<const VarNode*, bool>* var_to_purity_; |
529 | bool default_purity_; |
530 | }; |
531 | |
532 | } // namespace |
533 | |
534 | namespace transform { |
535 | |
536 | // Declared in relay/transform.h |
537 | Pass DeadCodeElimination(bool inline_once, bool ignore_impurity) { |
538 | auto pass_func = [=](IRModule mod, PassContext pc) -> IRModule { |
539 | VLOG(1) << "Before:" << std::endl << PrettyPrint(mod); |
540 | // Which let bindings are pure and can be safely elided? |
541 | std::unordered_map<const VarNode*, bool> var_to_purity; |
542 | if (!ignore_impurity) { |
543 | VLOG(1) << "determine purity" ; |
544 | PurityVisitor purity_visitor(mod); |
545 | purity_visitor.VisitModule(); |
546 | var_to_purity = purity_visitor.GetPurityMap(); |
547 | } |
548 | |
549 | IRModule result(/*functions=*/{}, mod->type_definitions, mod->Imports(), mod->source_map, |
550 | mod->attrs); |
551 | for (const auto& kv : mod->functions) { |
552 | if (const auto* function_node = kv.second.as<FunctionNode>()) { |
553 | auto function = GetRef<Function>(function_node); |
554 | |
555 | VLOG(1) << "processing " << PrettyPrint(kv.first); |
556 | |
557 | VLOG(2) << "count usage" ; |
558 | UsageVisitor usage_visitor(&var_to_purity, /*default_purity=*/ignore_impurity); |
559 | usage_visitor.VisitExpr(function); |
560 | |
561 | // Actually eliminate/inline the let-bindings. |
562 | VLOG(2) << "eliminate" ; |
563 | EliminatorMutator eliminator_mutator(inline_once, &usage_visitor.let_bound_values_, |
564 | &usage_visitor.use_map_, &var_to_purity, |
565 | /*default_purity=*/ignore_impurity); |
566 | result->Add(kv.first, Downcast<Function>(eliminator_mutator.VisitExpr(function))); |
567 | } else { |
568 | // PrimFuncs come across unchanged. |
569 | result->Add(kv.first, kv.second); |
570 | } |
571 | } |
572 | VLOG(1) << "After:" << std::endl << PrettyPrint(result); |
573 | |
574 | return result; |
575 | }; |
576 | return tvm::transform::CreateModulePass(pass_func, /*opt_level=*/1, "DeadCodeElimination" , |
577 | {"InferType" }); |
578 | } |
579 | |
580 | TVM_REGISTER_GLOBAL("relay._transform.DeadCodeElimination" ).set_body_typed(DeadCodeElimination); |
581 | |
582 | } // namespace transform |
583 | |
584 | } // namespace relay |
585 | } // namespace tvm |
586 | |