1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
 * \file tvm/relay/transforms/pass_utils.h
23 | * \brief Utilities for writing passes |
24 | */ |
25 | #ifndef TVM_RELAY_TRANSFORMS_PASS_UTILS_H_ |
26 | #define TVM_RELAY_TRANSFORMS_PASS_UTILS_H_ |
27 | |
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>

#include <functional>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "../analysis/dependency_graph.h"
#include "../op/annotation/annotation.h"
#include "../op/memory/on_device.h"
#include "./let_list.h"
42 | |
43 | namespace tvm { |
44 | namespace relay { |
45 | |
46 | /*! |
47 | * \brief Check if expr is positive constant. |
48 | * \param expr The expression to be checked. |
49 | * \return Whether all elements of expr is positive constant. |
50 | */ |
51 | bool IsAllPositiveConstant(const Expr& expr); |
52 | |
53 | /*! |
54 | * \brief Substitute var with subst. |
55 | * \param type The type to be substituted. |
56 | * \param tvar The type variable to be substituted. |
57 | * \param subst The target of substitution. |
58 | * \return The substituted result. |
59 | */ |
60 | Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst); |
61 | |
62 | /*! |
63 | * \brief Substitute var with subst. |
64 | * \param expr The expr to be substituted. |
65 | * \param tvar The type variable to be substituted. |
66 | * \param subst The target of substitution. |
67 | * \return The substituted result. |
68 | */ |
69 | Expr TypeSubst(const Expr& expr, const TypeVar& tvar, const Type& subst); |
70 | |
71 | /*! |
72 | * \brief Substitute type vars in type. |
73 | * \param type The type to be substituted. |
74 | * \param subst_map The map of substitution. |
75 | * \return The substituted result. |
76 | */ |
77 | Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map); |
78 | |
79 | /*! |
80 | * \brief Substitute type vars in type. |
81 | * \param expr The expr to be substituted. |
82 | * \param subst_map The map of substitution. |
83 | * \return The substituted result. |
84 | */ |
85 | Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map); |
86 | |
87 | /*! |
88 | * \brief Check if type is dynamic. |
89 | * \param ty The type to be checked. |
90 | * \return Whether the type is dynamic. |
91 | */ |
92 | bool IsDynamic(const Type& ty); |
93 | |
94 | /*! |
95 | * \brief Check if call is data dependent. |
96 | * \param call The call to be checked. |
97 | * \return Whether the call is data dependent. |
98 | */ |
99 | bool IsDataDependent(const CallNode* call); |
100 | |
101 | /*! |
102 | * \brief Make arbitrary transformation preserve the out most function. |
103 | * \param func The transformation. |
104 | * \param e The expression |
105 | * \return the transformed expression. If e is a function the return is also a function. |
106 | */ |
107 | inline Expr TransformF(const std::function<Expr(const Expr&)>& func, const Expr& e) { |
108 | if (const FunctionNode* f = e.as<FunctionNode>()) { |
109 | return WithFields(GetRef<Function>(f), f->params, func(f->body)); |
110 | } else { |
111 | return func(e); |
112 | } |
113 | } |
114 | |
115 | /*! |
116 | * \brief Decide whether the expression atomic or not? |
117 | * \param e the expression |
118 | * \return |
119 | * is it atomic? |
120 | * if so, the compute cost of the expression is bounded so it can be copy without graph mode. |
121 | */ |
122 | inline bool IsAtomic(const Expr& expr) { |
123 | Expr true_expr = IgnoreOnDevice(expr); |
124 | return true_expr.as<VarNode>() || true_expr.as<OpNode>() || true_expr.as<ConstructorNode>() || |
125 | true_expr.as<GlobalVarNode>() || |
126 | true_expr.as<ConstantNode>(); // Constant is always by reference. |
127 | } |
128 | |
129 | /*! |
130 | * \brief Cache the compiler_begin annotation op to reduce registry lookup overhead |
131 | * \param void |
132 | * \return compiler_begin op |
133 | */ |
134 | inline const Op& CompilerBeginOp() { |
135 | static auto op = Op::Get("annotation.compiler_begin" ); |
136 | return op; |
137 | } |
138 | |
139 | /*! |
140 | * \brief Cache the compiler_end annotation op to reduce registry lookup overhead |
141 | * \param void |
142 | * \return compiler_end op |
143 | */ |
144 | inline const Op& CompilerEndOp() { |
145 | static auto op = Op::Get("annotation.compiler_end" ); |
146 | return op; |
147 | } |
148 | |
/*!
 * \brief Common base of the tree node kinds (leaf, fatal leaf, branch).
 * \tparam ConditionObjectPtr The type of the condition stored in branch nodes.
 *
 * Nodes are handled through the shared `pointer` alias. The virtual destructor
 * makes the type polymorphic.
 */
template <typename ConditionObjectPtr>
struct TreeNode {
  /*! \brief Shared handle to any node of this tree. */
  using pointer = std::shared_ptr<TreeNode<ConditionObjectPtr>>;
  virtual ~TreeNode() = default;
};
154 | |
155 | template <typename ConditionObjectPtr> |
156 | struct TreeLeafNode : TreeNode<ConditionObjectPtr> { |
157 | using TreeObjectPtr = typename TreeNode<ConditionObjectPtr>::pointer; |
158 | |
159 | Expr body; |
160 | |
161 | explicit TreeLeafNode(Expr body) : body(body) {} |
162 | |
163 | static TreeObjectPtr Make(Expr body) { return std::make_shared<TreeLeafNode>(body); } |
164 | |
165 | ~TreeLeafNode() {} |
166 | }; |
167 | |
168 | template <typename ConditionObjectPtr> |
169 | struct TreeLeafFatalNode : TreeNode<ConditionObjectPtr> { |
170 | using TreeObjectPtr = typename TreeNode<ConditionObjectPtr>::pointer; |
171 | |
172 | TreeLeafFatalNode() = default; |
173 | |
174 | static TreeObjectPtr Make() { return std::make_shared<TreeLeafFatalNode>(); } |
175 | |
176 | ~TreeLeafFatalNode() {} |
177 | }; |
178 | |
179 | template <typename ConditionObjectPtr> |
180 | struct TreeBranchNode : TreeNode<ConditionObjectPtr> { |
181 | using TreeObjectPtr = typename TreeNode<ConditionObjectPtr>::pointer; |
182 | |
183 | ConditionObjectPtr cond; |
184 | TreeObjectPtr then_branch; |
185 | TreeObjectPtr else_branch; |
186 | |
187 | TreeBranchNode(ConditionObjectPtr cond, TreeObjectPtr then_branch, TreeObjectPtr else_branch) |
188 | : cond(cond), then_branch(then_branch), else_branch(else_branch) {} |
189 | |
190 | static TreeObjectPtr Make(ConditionObjectPtr cond, TreeObjectPtr then_branch, |
191 | TreeObjectPtr else_branch) { |
192 | return std::make_shared<TreeBranchNode>(cond, then_branch, else_branch); |
193 | } |
194 | |
195 | ~TreeBranchNode() {} |
196 | }; |
197 | |
198 | struct ScopeNode; |
199 | using Scope = std::shared_ptr<ScopeNode>; |
200 | using NodeScopeMap = std::unordered_map<DependencyGraph::Node*, Scope>; |
201 | using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>; |
202 | |
203 | /* Invariant: when parent is null level is 0 |
204 | * Invariant: when parent is not null level is 1 + parent->level |
205 | */ |
206 | struct ScopeNode { |
207 | // the level of the scope |
208 | size_t level; |
209 | // the parent scope |
210 | Scope parent; |
211 | // the corresponding let list which holds all let bindings in the scope |
212 | std::shared_ptr<LetList> let_list = std::make_shared<LetList>(); |
213 | explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) {} |
214 | ScopeNode() : level(0) {} |
215 | }; |
216 | |
217 | /*! \brief Calculate the scope of nodes in the dependency graph by least common ancestor. |
218 | * |
219 | * \param dg the input dependency graph |
220 | * \param expr_scope the output node -> scope mapping for all nodes. |
221 | * \param lifted_exprs the output set of expressions whose scope is lifted due to dependency |
222 | */ |
223 | std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg); |
224 | |
225 | /*! \brief find the least common ancestor of lhs scope and rhs scope. |
226 | */ |
227 | Scope LCA(Scope lhs, Scope rhs); |
228 | |
229 | // For basic block normal form. |
230 | Expr ToBasicBlockNormalFormAux(const Expr& e); |
231 | |
232 | // ToANormalForm for expressions and as a Pass are declared in transform.h |
233 | |
234 | } // namespace relay |
235 | } // namespace tvm |
236 | #endif // TVM_RELAY_TRANSFORMS_PASS_UTILS_H_ |
237 | |