1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/relay/expr_functor.h |
22 | * \brief A more powerful visitor which enables defining arbitrary function |
23 | * signatures with type based dispatch on first argument. |
24 | */ |
25 | #ifndef TVM_RELAY_EXPR_FUNCTOR_H_ |
26 | #define TVM_RELAY_EXPR_FUNCTOR_H_ |
27 | |
28 | #include <tvm/node/functor.h> |
29 | #include <tvm/relay/adt.h> |
30 | #include <tvm/relay/error.h> |
31 | #include <tvm/relay/expr.h> |
32 | #include <tvm/relay/function.h> |
33 | #include <tvm/relay/op.h> |
34 | |
#include <deque>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
40 | |
41 | namespace tvm { |
42 | namespace relay { |
43 | |
/*!
 * \brief A dynamic functor that dispatches on the type of its first Expr argument.
 *  It can be used as a more powerful visitor, because the signature of the
 *  visit functions (return type and extra arguments) is chosen by the user.
 *
 * \sa tvm/ir_functor.h
 *
 * \tparam FType The function signature.
 *  Only signatures of the form R(const Expr&, Args...) have a definition.
 */
template <typename FType>
class ExprFunctor;
57 | |
// Default body for the VisitExpr_ overloads that subclasses may override:
// forwards the node `op` and the extra arguments `args` to VisitExprDefault_.
#define EXPR_FUNCTOR_DEFAULT \
  { return VisitExprDefault_(op, std::forward<Args>(args)...); }

// Registers a handler for node type OP in the local `vtable`. The handler
// downcasts the reference to `const OP*` and calls the matching VisitExpr_
// overload on the functor, forwarding the extra arguments.
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \
  vtable.template set_dispatch<OP>([](const ObjectRef& ref, TSelf* functor, Args... extra) { \
    return functor->VisitExpr_(static_cast<const OP*>(ref.get()), std::forward<Args>(extra)...); \
  });
66 | |
67 | template <typename R, typename... Args> |
68 | class ExprFunctor<R(const Expr& n, Args...)> { |
69 | private: |
70 | using TSelf = ExprFunctor<R(const Expr& n, Args...)>; |
71 | using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; |
72 | |
73 | public: |
74 | /*! \brief the result type of this functor */ |
75 | using result_type = R; |
76 | /*! \brief virtual destructor */ |
77 | virtual ~ExprFunctor() {} |
78 | /*! |
79 | * \brief Same as call. |
80 | * \param n The expression node. |
81 | * \param args Additional arguments. |
82 | * \return The result of the call |
83 | */ |
84 | R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward<Args>(args)...); } |
85 | /*! |
86 | * \brief The functor call. |
87 | * \param n The expression node. |
88 | * \param args Additional arguments. |
89 | * \return The result of the call |
90 | */ |
91 | virtual R VisitExpr(const Expr& n, Args... args) { |
92 | ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may " |
93 | "have generated invalid data." ; |
94 | static FType vtable = InitVTable(); |
95 | return vtable(n, this, std::forward<Args>(args)...); |
96 | } |
97 | // Functions that can be overriden by subclass |
98 | virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
99 | virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
100 | virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
101 | virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
102 | virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
103 | virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
104 | virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
105 | virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
106 | virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
107 | virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
108 | virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
109 | virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
110 | virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
111 | virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
112 | virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
113 | virtual R VisitExprDefault_(const Object* op, Args...) { |
114 | LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); |
115 | throw; |
116 | } |
117 | |
118 | private: |
119 | // initialize the vtable. |
120 | static FType InitVTable() { |
121 | FType vtable; |
122 | // Set dispatch |
123 | RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode); |
124 | RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode); |
125 | RELAY_EXPR_FUNCTOR_DISPATCH(VarNode); |
126 | RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode); |
127 | RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode); |
128 | RELAY_EXPR_FUNCTOR_DISPATCH(CallNode); |
129 | RELAY_EXPR_FUNCTOR_DISPATCH(LetNode); |
130 | RELAY_EXPR_FUNCTOR_DISPATCH(IfNode); |
131 | RELAY_EXPR_FUNCTOR_DISPATCH(OpNode); |
132 | RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode); |
133 | RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode); |
134 | RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode); |
135 | RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode); |
136 | RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode); |
137 | RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode); |
138 | return vtable; |
139 | } |
140 | }; |
141 | |
142 | /*! |
143 | * \brief A simple visitor wrapper around ExprFunctor. |
144 | * Recursively visit the content. |
145 | * |
146 | * ExprVisitor treats Expr as dataflow graph, |
147 | * and only visit each Expr node once. |
148 | */ |
149 | class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> { |
150 | public: |
151 | void VisitExpr(const Expr& expr) override; |
152 | void VisitExpr_(const VarNode* op) override; |
153 | void VisitExpr_(const GlobalVarNode* op) override; |
154 | void VisitExpr_(const ConstantNode* op) override; |
155 | void VisitExpr_(const TupleNode* op) override; |
156 | void VisitExpr_(const FunctionNode* op) override; |
157 | void VisitExpr_(const CallNode* op) override; |
158 | void VisitExpr_(const LetNode* op) override; |
159 | void VisitExpr_(const IfNode* op) override; |
160 | void VisitExpr_(const OpNode* op) override; |
161 | void VisitExpr_(const TupleGetItemNode* op) override; |
162 | void VisitExpr_(const RefCreateNode* op) override; |
163 | void VisitExpr_(const RefReadNode* op) override; |
164 | void VisitExpr_(const RefWriteNode* op) override; |
165 | void VisitExpr_(const ConstructorNode* op) override; |
166 | void VisitExpr_(const MatchNode* op) override; |
167 | virtual void VisitType(const Type& t); |
168 | virtual void VisitClause(const Clause& c); |
169 | virtual void VisitPattern(const Pattern& c); |
170 | virtual void VisitSpan(const Span& span); |
171 | |
172 | protected: |
173 | // Internal visiting counter |
174 | std::unordered_map<const Object*, size_t> visit_counter_; |
175 | }; |
176 | |
177 | /*! |
178 | * \brief A wrapper around ExprFunctor which functionally updates the AST. |
179 | * |
180 | * ExprMutator treats Expr as dataflow graph, and only Mutate each Expr once. |
181 | * The mutated results are memoized in a map and reused so that |
182 | * local transformation on the dataflow preserves the graph structure. |
183 | */ |
184 | class ExprMutator : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> { |
185 | public: |
186 | /*! |
187 | * \brief Mutate is alias for VisitExpr |
188 | * \return expr. |
189 | */ |
190 | Expr Mutate(const Expr& expr) { return this->VisitExpr(expr); } |
191 | Expr VisitExpr(const Expr& expr) override; |
192 | Expr VisitExpr_(const VarNode* op) override; |
193 | Expr VisitExpr_(const ConstantNode* op) override; |
194 | Expr VisitExpr_(const GlobalVarNode* op) override; |
195 | Expr VisitExpr_(const OpNode* op) override; |
196 | Expr VisitExpr_(const TupleNode* op) override; |
197 | Expr VisitExpr_(const FunctionNode* op) override; |
198 | Expr VisitExpr_(const CallNode* call_node) override; |
199 | Expr VisitExpr_(const LetNode* op) override; |
200 | Expr VisitExpr_(const IfNode* op) override; |
201 | Expr VisitExpr_(const TupleGetItemNode* op) override; |
202 | Expr VisitExpr_(const RefCreateNode* op) override; |
203 | Expr VisitExpr_(const RefReadNode* op) override; |
204 | Expr VisitExpr_(const RefWriteNode* op) override; |
205 | Expr VisitExpr_(const ConstructorNode* op) override; |
206 | Expr VisitExpr_(const MatchNode* op) override; |
207 | |
208 | /*! |
209 | * \brief Used to visit the types inside of expressions. |
210 | * |
211 | * Can be overloaded to transform the types in arbitrary |
212 | * ways, one way would be to define a sub-class of type |
213 | * visitor for types which transform them appropriately. |
214 | */ |
215 | virtual Type VisitType(const Type& t); |
216 | virtual Clause VisitClause(const Clause& c); |
217 | virtual Pattern VisitPattern(const Pattern& c); |
218 | |
219 | protected: |
220 | /*! \brief Internal map used for memoization. */ |
221 | std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> memo_; |
222 | }; |
223 | |
224 | /*! |
225 | * \brief A wrapper around ExprVisitor which traverses the Dataflow Normal AST. |
226 | * |
227 | * MixedModeVisitor treats Expr as dataflow graph, and visits in post-DFS order |
228 | * |
229 | * MixedModeVisitor provides the same recursive API as ExprVisitor, and uses |
230 | * recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions |
231 | * of the graph and processes them iteratively to prevent stack overflows |
232 | */ |
233 | class MixedModeVisitor : public ::tvm::relay::ExprVisitor { |
234 | public: |
235 | using ::tvm::relay::ExprFunctor<void(const Expr& n)>::VisitExpr_; |
236 | |
237 | /*! \brief The constructor of MixedModeVisitor |
238 | * \param visit_limit The number of times to allow visitation to a node. Usually 1, ocassionally |
239 | * higher (i.e., 2 for dead code elimiation), limited to 10 as a sanity check. |
240 | */ |
241 | explicit MixedModeVisitor(int visit_limit = 1); |
242 | |
243 | using ExprVisitor::VisitExpr_; |
244 | |
245 | /*! |
246 | * \brief VisitExpr is finalized to preserve call expansion of dataflow regions |
247 | */ |
248 | void VisitExpr(const Expr& expr) final; |
249 | void VisitExpr_(const CallNode* op) override; |
250 | void VisitExpr_(const TupleNode* op) override; |
251 | void VisitExpr_(const TupleGetItemNode* op) override; |
252 | |
253 | protected: |
254 | /*! |
255 | * \brief A function to apply when reaching a leaf of the graph non-recursively |
256 | */ |
257 | virtual void VisitLeaf(const Expr& expr); |
258 | /*! |
259 | * \brief A function to determine if an expression has already been visited or needs to be |
260 | * re-visited |
261 | */ |
262 | virtual bool CheckVisited(const Expr& expr); |
263 | /*! |
264 | * \brief The max number of times to visit a node |
265 | */ |
266 | size_t visit_limit_; |
267 | }; |
268 | |
269 | /*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes |
270 | * |
271 | * MixedModeMutator treats Expr as dataflow graph, and only Rewrites each Expr once. |
272 | * The mutated results are memoized in a map and reused so that |
273 | * local transformation on the dataflow preserves the graph structure. |
274 | * |
275 | * MixedModeMutator provides the same recursive API as ExprMutator, and uses |
276 | * recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions |
277 | * of the graph and processes them iteratatively to prevent stack overflows |
278 | * |
279 | * Uses Rewrite_ API of ExprRewriter for a cleaner split between recrusive and non-recursive |
280 | * behavior. |
281 | */ |
282 | class MixedModeMutator : public ::tvm::relay::ExprMutator { |
283 | public: |
284 | using ::tvm::relay::ExprFunctor<Expr(const Expr&)>::VisitExpr_; |
285 | |
286 | MixedModeMutator(bool pre = false) : pre_{pre} {}; |
287 | Expr VisitExpr(const Expr& expr) final; |
288 | |
289 | virtual Expr DispatchVisitExpr(const Expr& expr); |
290 | Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); }; |
291 | Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); }; |
292 | Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); }; |
293 | /*! |
294 | * \brief Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will |
295 | * be able to rewrite the op only with data about the original node `pre` and the same node with |
296 | * modified inputs `post` and should not recurse. |
297 | * |
298 | * \param pre The expression node before rewriting. |
299 | * \param post The expression with rewritten inputs. |
300 | */ |
301 | virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post; } |
302 | virtual Expr Rewrite_(const CallNode* pre, const Expr& post) { return post; } |
303 | virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; } |
304 | |
305 | protected: |
306 | bool pre_; |
307 | /*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with |
308 | * changed inputs. |
309 | */ |
310 | template <typename T> |
311 | Expr Rewrite(const T* op) { |
312 | Expr post = ExprMutator::VisitExpr_(op); |
313 | return Rewrite_(op, post); |
314 | } |
315 | |
316 | virtual void VisitLeaf(const Expr& expr); |
317 | virtual bool CheckVisited(const Expr& expr); |
318 | }; |
319 | |
// Registers a handler for node type OP in the local `vtable`. The handler
// downcasts the reference to `const OP*` and calls the matching Rewrite_
// overload on the rewriter, passing the node with rewritten inputs along.
#define RELAY_EXPR_REWRITER_DISPATCH(OP) \
  vtable.template set_dispatch<OP>([](const ObjectRef& ref, TSelf* rewriter, const Expr& rewritten) { \
    return rewriter->Rewrite_(static_cast<const OP*>(ref.get()), rewritten); \
  });

// Default body of every Rewrite_ overload: return the `post` argument unchanged.
#define EXPR_REWRITER_REWRITE_DEFAULT \
  { return post; }
327 | |
328 | /*! \brief A non-iterating Expression Rewriter |
329 | * |
330 | * ExprRewriter provides a Rewrite interface for modifying graphs in Post-DFS order. |
331 | * |
332 | * The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will |
333 | * non-recursively unroll the graph and call Rewriting on inputs. It will then pass the original |
334 | * node, called `pre`, and a node recreated with any alterned inputs, called `post`, to the |
335 | * ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more complex |
336 | * graph rewriting. |
337 | */ |
338 | class ExprRewriter { |
339 | private: |
340 | using TSelf = ExprRewriter; |
341 | using FType = tvm::NodeFunctor<Expr(const ObjectRef& n, TSelf* self, const Expr& post)>; |
342 | |
343 | public: |
344 | /*! \brief virtual destructor */ |
345 | virtual ~ExprRewriter() {} |
346 | /*! |
347 | * \brief Same as call. |
348 | * \param pre The expression node before rewriting. |
349 | * \param post The expression node with rewritten inputs. |
350 | * \return The result of the call |
351 | */ |
352 | Expr operator()(const Expr& pre, const Expr& post) { return Rewrite(pre, post); } |
353 | /*! |
354 | * \brief The functor call. |
355 | * \param pre The expression node before rewriting. |
356 | * \param post The expression node with rewritten inputs. |
357 | * \return The result of the call |
358 | */ |
359 | virtual Expr Rewrite(const Expr& pre, const Expr& post) { |
360 | ICHECK(pre.defined()); |
361 | static FType vtable = InitVTable(); |
362 | return vtable(pre, this, post); |
363 | } |
364 | // Functions that can be overriden by subclass, should not recurse |
365 | virtual Expr Rewrite_(const VarNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT; |
366 | virtual Expr Rewrite_(const GlobalVarNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT; |
367 | virtual Expr Rewrite_(const ConstantNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT; |
368 | virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT; |
369 | virtual Expr Rewrite_(const FunctionNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT; |
370 | virtual Expr Rewrite_(const CallNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT; |
371 | virtual Expr Rewrite_(const LetNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT; |
372 | virtual Expr Rewrite_(const IfNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT; |
373 | virtual Expr Rewrite_(const OpNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT; |
374 | virtual Expr Rewrite_(const TupleGetItemNode* pre, |
375 | const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT; |
376 | virtual Expr Rewrite_(const RefCreateNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT; |
377 | virtual Expr Rewrite_(const RefReadNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT; |
378 | virtual Expr Rewrite_(const RefWriteNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT; |
379 | virtual Expr Rewrite_(const ConstructorNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT; |
380 | virtual Expr Rewrite_(const MatchNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT; |
381 | |
382 | private: |
383 | // initialize the vtable. |
384 | static FType InitVTable() { |
385 | FType vtable; |
386 | // Set dispatch |
387 | RELAY_EXPR_REWRITER_DISPATCH(ConstantNode); |
388 | RELAY_EXPR_REWRITER_DISPATCH(TupleNode); |
389 | RELAY_EXPR_REWRITER_DISPATCH(VarNode); |
390 | RELAY_EXPR_REWRITER_DISPATCH(GlobalVarNode); |
391 | RELAY_EXPR_REWRITER_DISPATCH(FunctionNode); |
392 | RELAY_EXPR_REWRITER_DISPATCH(CallNode); |
393 | RELAY_EXPR_REWRITER_DISPATCH(LetNode); |
394 | RELAY_EXPR_REWRITER_DISPATCH(IfNode); |
395 | RELAY_EXPR_REWRITER_DISPATCH(OpNode); |
396 | RELAY_EXPR_REWRITER_DISPATCH(TupleGetItemNode); |
397 | RELAY_EXPR_REWRITER_DISPATCH(RefCreateNode); |
398 | RELAY_EXPR_REWRITER_DISPATCH(RefReadNode); |
399 | RELAY_EXPR_REWRITER_DISPATCH(RefWriteNode); |
400 | RELAY_EXPR_REWRITER_DISPATCH(ConstructorNode); |
401 | RELAY_EXPR_REWRITER_DISPATCH(MatchNode); |
402 | return vtable; |
403 | } |
404 | }; |
405 | |
406 | /*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes |
407 | * |
408 | * PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the |
409 | * ExprRewriter's Rewrite functions on nodes once their inputs are rewritten. At each rewrite call, |
410 | * PostOrderRewrite provides the original node and the node with altered inputs for use by the |
411 | * ExprRewriter. |
412 | */ |
413 | Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter); |
414 | |
415 | /*! |
416 | * \brief recursively visit the ir in post DFS order node, apply fvisit |
417 | * Each node is guaranteed to be visited only once. |
418 | * \param node The ir to be visited. |
419 | * \param fvisit The visitor function to be applied. |
420 | */ |
421 | void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit); |
422 | |
423 | /*! |
424 | * \brief A struct to keep info of traversed expr in ExpandDataflow function |
425 | */ |
426 | struct v_info { |
427 | explicit v_info(Expr node_) : node{node_} {} |
428 | v_info(Expr node_, bool children_expanded_) |
429 | : node{node_}, children_expanded{children_expanded_} {}; |
430 | Expr node{}; |
431 | bool children_expanded{false}; |
432 | }; |
433 | |
434 | /*! |
435 | * \brief A function to iteratively traverse dataflow regions of a graph |
436 | * |
437 | * ExpandDataflow manually manages a stack and performs DFS to determine the processing |
438 | * order of nodes in an input graph. |
439 | * |
440 | * By default fexpand_expr implemented in a way that if it finds a dataflow node (Call, Tuple, |
441 | * TupleGetItem), it checks if the arguments to that node need to be processed via fcheck_visited. |
442 | * If so, the function pushes those arguments to the stack and continues iteratively to process |
443 | * the top of the stack. When it finds a node that doesn't match the dataflow types, or a node who's |
444 | * inputs have all been processed, it visits the current leaf via fvisit_leaf. |
445 | * |
446 | * This function should be used internally to other classes to implement mixed-mode traversals. The |
447 | * expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it |
448 | * hits a non-dataflow node. |
449 | * |
450 | * fcheck_visited, fvisit_leaf and fexpand_expr are templated to encourage reusing. |
451 | */ |
452 | template <typename FCheckVisited, typename FVisitLeaf, typename FExpandExpr> |
453 | void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf, |
454 | FExpandExpr fexpand_expr) { |
455 | std::deque<v_info> stack; |
456 | auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) { |
457 | if (!fcheck_visited(expr)) { |
458 | stack.emplace_front(v_info(expr)); |
459 | } |
460 | }; |
461 | |
462 | fpush_to_stack(expr); |
463 | while (stack.size() > 0) { |
464 | v_info* front = &stack.front(); |
465 | if (fcheck_visited(front->node)) { |
466 | stack.pop_front(); |
467 | } else if (front->children_expanded) { |
468 | fvisit_leaf(front->node); |
469 | // TODO(d-smirnov): this is for compatibility with current implementation of MixedModeVisitor |
470 | stack.pop_front(); |
471 | } else { |
472 | front->children_expanded = true; |
473 | for (auto e : fexpand_expr(front->node)) { |
474 | fpush_to_stack(e); |
475 | } |
476 | } |
477 | } |
478 | } |
479 | |
480 | template <typename FCheckVisited, typename FVisitLeaf> |
481 | void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) { |
482 | auto fexpand_expr = [](const Expr& expr) { |
483 | std::vector<Expr> result; |
484 | if (const CallNode* op = expr.as<CallNode>()) { |
485 | if (op->op == Op::Get("call_lowered" )) { |
486 | // Ignore the intermediate tuple since this is purely a calling-convention detail |
487 | const auto* tuple_args = op->args[1].as<TupleNode>(); |
488 | ICHECK(tuple_args) |
489 | << "Expected second arg to call_lowered to be a Tuple of input arguments." ; |
490 | for (auto it = tuple_args->fields.rbegin(); it != tuple_args->fields.rend(); ++it) { |
491 | result.push_back(*it); |
492 | } |
493 | result.push_back(op->args[0]); |
494 | } else { |
495 | for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) { |
496 | result.push_back(*it); |
497 | } |
498 | } |
499 | result.push_back(op->op); |
500 | } else if (const TupleNode* op = expr.as<TupleNode>()) { |
501 | for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) { |
502 | result.push_back(*it); |
503 | } |
504 | } else if (const TupleGetItemNode* op = expr.as<TupleGetItemNode>()) { |
505 | result.push_back(op->tuple); |
506 | } |
507 | return result; |
508 | }; |
509 | ExpandDataflow(expr, fcheck_visited, fvisit_leaf, fexpand_expr); |
510 | } |
511 | |
512 | void ExpandANormalForm(const LetNode* op, std::function<void(const LetNode*)> pre_visit, |
513 | std::function<void(const LetNode*)> post_visit); |
514 | |
515 | } // namespace relay |
516 | } // namespace tvm |
517 | #endif // TVM_RELAY_EXPR_FUNCTOR_H_ |
518 | |