1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
 * \file tvm/relay/transforms/pass_utils.h
23 * \brief Utilities for writing passes
24 */
25#ifndef TVM_RELAY_TRANSFORMS_PASS_UTILS_H_
26#define TVM_RELAY_TRANSFORMS_PASS_UTILS_H_
27
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>

#include <functional>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "../analysis/dependency_graph.h"
#include "../op/annotation/annotation.h"
#include "../op/memory/on_device.h"
#include "./let_list.h"
42
43namespace tvm {
44namespace relay {
45
/*!
 * \brief Check whether expr is a positive constant.
 * \param expr The expression to be checked.
 * \return Whether all elements of expr are positive constants.
 */
bool IsAllPositiveConstant(const Expr& expr);
52
/*!
 * \brief Substitute the type variable tvar with subst inside a type.
 * \param type The type in which the substitution is performed.
 * \param tvar The type variable to be replaced.
 * \param subst The type that replaces tvar.
 * \return The substituted result.
 */
Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst);

/*!
 * \brief Substitute the type variable tvar with subst inside an expression.
 * \param expr The expression in which the substitution is performed.
 * \param tvar The type variable to be replaced.
 * \param subst The type that replaces tvar.
 * \return The substituted result.
 */
Expr TypeSubst(const Expr& expr, const TypeVar& tvar, const Type& subst);

/*!
 * \brief Substitute type vars in a type according to a substitution map.
 * \param type The type in which the substitution is performed.
 * \param subst_map Map from each type variable to the type that replaces it.
 * \return The substituted result.
 */
Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map);

/*!
 * \brief Substitute type vars in an expression according to a substitution map.
 * \param expr The expression in which the substitution is performed.
 * \param subst_map Map from each type variable to the type that replaces it.
 * \return The substituted result.
 */
Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map);
86
/*!
 * \brief Check whether a type is dynamic.
 * \param ty The type to be checked.
 * \return Whether the type is dynamic.
 * \note NOTE(review): the exact criterion for "dynamic" (e.g. unknown dimensions)
 *       lives with the implementation — confirm there before relying on it.
 */
bool IsDynamic(const Type& ty);

/*!
 * \brief Check whether a call is data dependent.
 * \param call The call to be checked (passed as a raw node pointer; presumably
 *        must be non-null — confirm against the implementation).
 * \return Whether the call is data dependent.
 */
bool IsDataDependent(const CallNode* call);
100
101/*!
102 * \brief Make arbitrary transformation preserve the out most function.
103 * \param func The transformation.
104 * \param e The expression
105 * \return the transformed expression. If e is a function the return is also a function.
106 */
107inline Expr TransformF(const std::function<Expr(const Expr&)>& func, const Expr& e) {
108 if (const FunctionNode* f = e.as<FunctionNode>()) {
109 return WithFields(GetRef<Function>(f), f->params, func(f->body));
110 } else {
111 return func(e);
112 }
113}
114
115/*!
116 * \brief Decide whether the expression atomic or not?
117 * \param e the expression
118 * \return
119 * is it atomic?
120 * if so, the compute cost of the expression is bounded so it can be copy without graph mode.
121 */
122inline bool IsAtomic(const Expr& expr) {
123 Expr true_expr = IgnoreOnDevice(expr);
124 return true_expr.as<VarNode>() || true_expr.as<OpNode>() || true_expr.as<ConstructorNode>() ||
125 true_expr.as<GlobalVarNode>() ||
126 true_expr.as<ConstantNode>(); // Constant is always by reference.
127}
128
129/*!
130 * \brief Cache the compiler_begin annotation op to reduce registry lookup overhead
131 * \param void
132 * \return compiler_begin op
133 */
134inline const Op& CompilerBeginOp() {
135 static auto op = Op::Get("annotation.compiler_begin");
136 return op;
137}
138
139/*!
140 * \brief Cache the compiler_end annotation op to reduce registry lookup overhead
141 * \param void
142 * \return compiler_end op
143 */
144inline const Op& CompilerEndOp() {
145 static auto op = Op::Get("annotation.compiler_end");
146 return op;
147}
148
/*!
 * \brief Polymorphic base for the nodes of a condition tree.
 * \tparam ConditionObjectPtr Type of the condition stored in branch nodes.
 */
template <typename ConditionObjectPtr>
struct TreeNode {
  /*! \brief Shared, type-erased handle to any tree node. */
  using pointer = std::shared_ptr<TreeNode<ConditionObjectPtr>>;

  // Virtual so derived nodes are destroyed correctly through `pointer`.
  virtual ~TreeNode() = default;
};
154
155template <typename ConditionObjectPtr>
156struct TreeLeafNode : TreeNode<ConditionObjectPtr> {
157 using TreeObjectPtr = typename TreeNode<ConditionObjectPtr>::pointer;
158
159 Expr body;
160
161 explicit TreeLeafNode(Expr body) : body(body) {}
162
163 static TreeObjectPtr Make(Expr body) { return std::make_shared<TreeLeafNode>(body); }
164
165 ~TreeLeafNode() {}
166};
167
168template <typename ConditionObjectPtr>
169struct TreeLeafFatalNode : TreeNode<ConditionObjectPtr> {
170 using TreeObjectPtr = typename TreeNode<ConditionObjectPtr>::pointer;
171
172 TreeLeafFatalNode() = default;
173
174 static TreeObjectPtr Make() { return std::make_shared<TreeLeafFatalNode>(); }
175
176 ~TreeLeafFatalNode() {}
177};
178
179template <typename ConditionObjectPtr>
180struct TreeBranchNode : TreeNode<ConditionObjectPtr> {
181 using TreeObjectPtr = typename TreeNode<ConditionObjectPtr>::pointer;
182
183 ConditionObjectPtr cond;
184 TreeObjectPtr then_branch;
185 TreeObjectPtr else_branch;
186
187 TreeBranchNode(ConditionObjectPtr cond, TreeObjectPtr then_branch, TreeObjectPtr else_branch)
188 : cond(cond), then_branch(then_branch), else_branch(else_branch) {}
189
190 static TreeObjectPtr Make(ConditionObjectPtr cond, TreeObjectPtr then_branch,
191 TreeObjectPtr else_branch) {
192 return std::make_shared<TreeBranchNode>(cond, then_branch, else_branch);
193 }
194
195 ~TreeBranchNode() {}
196};
197
198struct ScopeNode;
199using Scope = std::shared_ptr<ScopeNode>;
200using NodeScopeMap = std::unordered_map<DependencyGraph::Node*, Scope>;
201using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;
202
203/* Invariant: when parent is null level is 0
204 * Invariant: when parent is not null level is 1 + parent->level
205 */
206struct ScopeNode {
207 // the level of the scope
208 size_t level;
209 // the parent scope
210 Scope parent;
211 // the corresponding let list which holds all let bindings in the scope
212 std::shared_ptr<LetList> let_list = std::make_shared<LetList>();
213 explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) {}
214 ScopeNode() : level(0) {}
215};
216
/*!
 * \brief Calculate the scope of nodes in the dependency graph by least common ancestor.
 *
 * \param dg The input dependency graph.
 * \return A pair whose first element is the node -> scope mapping for all nodes, and
 *         whose second element is the set of expressions whose scope is lifted due
 *         to dependency.
 */
std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg);
224
/*!
 * \brief Find the least common ancestor of the lhs scope and the rhs scope.
 * \param lhs One of the two scopes.
 * \param rhs The other scope.
 * \return The least common ancestor scope of lhs and rhs.
 */
Scope LCA(Scope lhs, Scope rhs);
228
/*!
 * \brief Auxiliary routine used for basic block normal form conversion.
 * \param e The expression to process.
 * \return The resulting expression.
 * \note NOTE(review): the exact contract (what form e must be in, what is returned) is
 *       defined with the implementation — confirm there.
 */
Expr ToBasicBlockNormalFormAux(const Expr& e);

// ToANormalForm (for expressions and as a Pass) is declared in transform.h.
233
234} // namespace relay
235} // namespace tvm
236#endif // TVM_RELAY_TRANSFORMS_PASS_UTILS_H_
237