1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/tir/analysis.h |
22 | * \brief Analysis utilities and passes for TIR. |
23 | */ |
24 | #ifndef TVM_TIR_ANALYSIS_H_ |
25 | #define TVM_TIR_ANALYSIS_H_ |
26 | |
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>

#include <functional>
#include <string>
35 | |
36 | namespace tvm { |
37 | namespace tir { |
38 | |
39 | /*! |
40 | * \brief Compare two expressions recursively and check if they are equal |
41 | * to each other without var remapping. |
42 | * |
43 | * This function does not remap variable bindings, it will not |
44 | * return true for (let x = 1 in x + 1) vs (let y = 1 in y + 1), unless x.same_as(y). |
45 | * |
46 | * Use StructuralEqual for such cases. |
47 | * |
48 | * Due to the restriction of not remapping variables, this function can run |
49 | * faster than StructuralEqual and can be used as a utility function during arithmetic |
50 | * simplifications. |
51 | * |
52 | * \sa StructuralEqual |
53 | */ |
54 | struct ExprDeepEqual { |
55 | public: |
56 | TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const; |
57 | }; |
58 | |
59 | /*! |
60 | * \brief Visit the PrimFuncs in the IRModule |
61 | * \tparam FLambda The type of the PrimFunc visitor |
62 | * \param mod The IRModule to be visited |
63 | * \param fvisit The visitor to the PrimFuncs in the IRModule |
64 | */ |
65 | template <class FLambda> |
66 | inline void VisitPrimFuncs(const IRModule& mod, FLambda fvisit) { |
67 | for (const auto& kv : mod->functions) { |
68 | const BaseFunc& base_func = kv.second; |
69 | if (const auto* prim_func = base_func.as<PrimFuncNode>()) { |
70 | fvisit(prim_func); |
71 | } |
72 | } |
73 | } |
74 | |
75 | /*! |
76 | * \brief Estimate the FLOPs of a TIR fragment. |
77 | * \param stmt The TIR fragment to be estimated. |
78 | * \return The estimated FLOPs. |
79 | */ |
80 | TVM_DLL double EstimateTIRFlops(const Stmt& stmt); |
81 | |
82 | /*! |
83 | * \brief Estimate the FLOPs of TIRs in an IRModule. |
84 | * \param mod The IRModule to be estimated. |
85 | * \return The estimated FLOPs. |
86 | */ |
87 | TVM_DLL double EstimateTIRFlops(const IRModule& mod); |
88 | |
89 | /*! |
90 | * \brief Find undefined vars in the statement. |
91 | * \param stmt The function to be checked. |
92 | * \param defs The vars that is defined. |
93 | * \return Array of undefined vars. |
94 | */ |
95 | TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs); |
96 | |
97 | /*! |
98 | * \brief Find undefined vars in the expression. |
99 | * \param expr The expression to be checked. |
100 | * \return Array of undefined vars. |
101 | */ |
102 | TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr); |
103 | |
104 | /*! |
105 | * \brief Analyze the side effect |
106 | * \param expr The expression to be checked. |
107 | * |
108 | * \return CallEffectKind, can be kPure, kReadState or kUpdateState |
109 | */ |
110 | TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr); |
111 | |
112 | /*! |
113 | * \brief Whether the given Stmt uses any var in the given variable set. |
114 | * \param stmt The Stmt to be checked. |
115 | * \param vset_contains The check function to see if a var is in the variable set. |
116 | * \return Whether `stmt` uses any var in the given variable set. |
117 | */ |
118 | TVM_DLL bool UsesVar(const Stmt& stmt, std::function<bool(const VarNode*)> vset_contains); |
119 | |
120 | /*! |
121 | * \brief Whether the given PrimExpr uses any var in the given variable set. |
122 | * \param expr The PrimExpr to be checked. |
123 | * \param vset_contains The check function to see if var is in the variable set. |
124 | * \return Whether `expr` uses any var in the given variable set. |
125 | */ |
126 | TVM_DLL bool UsesVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains); |
127 | |
128 | /*! |
129 | * \brief Verifies whether the IR stmt or Expr is in SSA form. |
130 | * That is: each Var is defined and assigned once(in Let/For) |
131 | * |
132 | * \param func The function to be verified. |
133 | * \return Whether IR is in SSA form. |
134 | * |
135 | * \note All passes in TIR consume and produce SSA form. |
136 | */ |
137 | TVM_DLL bool VerifySSA(const PrimFunc& func); |
138 | |
139 | /*! |
140 | * \brief Verify if memory accesses are legal for a specific target device type. |
141 | * |
142 | * In the case that tgt is cuda, if not all workload is bound with |
143 | * threads, CPU code is generated that tries to access GPU memory, |
144 | * which is illegal. This pass performs verification for this case. |
145 | * |
146 | * \param func The function to be verified. |
147 | * \return Success of memory verification. |
148 | */ |
149 | TVM_DLL bool VerifyMemory(const PrimFunc& func); |
150 | |
151 | /*! |
152 | * \brief Verify the correctness of a GPU code |
153 | * It will check the whether the amount of memory usage or the number of threads |
154 | * in a block exceeds the limit |
155 | * \param func The function to be checked |
156 | * \param constraints The dict to specify constraints to check. |
157 | * Possible keys are |
158 | * |
159 | * "max_local_memory_per_block": Total amount of local memory per block (in bytes). |
160 | * "max_shared_memory_per_block": Total amount of shared memory per block (in bytes). |
161 | * "max_threads_per_block": Maximum number of threads per block. |
162 | * "max_thread_x": Maximum length of threadIdx.x. |
163 | * "max_thread_y": Maximum length of threadIdx.y. |
164 | * "max_thread_z": Maximum length of threadIdx.z. |
165 | * |
166 | * If one key is missing in this argument, the pass won't check for that item. |
167 | * \return valid Whether it is a valid GPU code |
168 | * |
169 | */ |
170 | TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints); |
171 | |
172 | /*! |
173 | * \brief Verifies that the VTCM usage of the given prim_func is within the provided limit. |
174 | * \param func The function to be checked. |
175 | * \param limit The limit to check. |
176 | * \return true if the VTCM usage is within the provided limit. |
177 | */ |
178 | TVM_DLL bool VerifyVTCMLimit(const PrimFunc& func, Integer limit); |
179 | |
180 | /*! |
181 | * \brief Auto detect the block access region according to its body stmt |
182 | * It will detect the access region as an array in order of appearance in AST |
183 | * \param block The block to be detected |
184 | * \param buffer_var_map The outside buffers which may be accessed the block. |
185 | * It is a map from buffer var to the buffer. |
186 | * \return Array of access regions. |
187 | * There are three arrays of BufferRegion: |
188 | * - first: read regions |
189 | * - second: write regions |
190 | * - third: opaque regions |
191 | */ |
192 | TVM_DLL Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block, |
193 | const Map<Var, Buffer>& buffer_var_map); |
194 | |
195 | /*! |
196 | * \brief Auto detect the block read/write region according to its body stmt. An opaque access will |
197 | * be counted as both a read and a write access |
198 | * \param block The block to be detected |
199 | * \param buffer_var_map The outside buffers which may be accessed the block. |
200 | * It is a map from buffer var to the buffer |
201 | * \return An array only consisting of the read regions and write regions of the input block |
202 | */ |
203 | TVM_DLL Array<Array<BufferRegion>> GetBlockReadWriteRegion(const Block& block, |
204 | const Map<Var, Buffer>& buffer_var_map); |
205 | |
206 | /*! |
207 | * \brief Calculate the expresion complexity based on number of symbols it contains. |
208 | * \param expr The expr to be calculated. |
209 | */ |
210 | TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr); |
211 | |
212 | /*! |
213 | * \brief Calculate the constants size in bytes needed by the TIR allocates inside the TIR PrimFunc |
214 | * \param func The TIR PrimFunc for which the constants size to be calculated |
215 | * \param constant_byte_alignment The byte alignment required for each constant allocated |
216 | */ |
217 | TVM_DLL size_t CalculateConstantBytes(const PrimFunc& func, const Integer& constant_byte_alignment); |
218 | |
219 | /*! |
220 | * \brief Calculate the workspace size in bytes needed by the TIR allocates inside the TIR PrimFunc |
221 | * \param func The TIR PrimFunc for which the workspace size to be calculated |
222 | * \param workspace_byte_alignment The byte alignment required for each tensor allocated in this |
223 | * workspace |
224 | */ |
225 | TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func, |
226 | const Integer& workspace_byte_alignment); |
227 | |
228 | /*! |
229 | * \brief Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc |
230 | * \param func The TIR PrimFunc for which the the allocated memory size to be calculated |
231 | */ |
232 | TVM_DLL tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func); |
233 | |
234 | /*! |
235 | * \brief Detect the lowest common ancestor(LCA) of buffer access, including both high-level |
236 | * access(BufferLoad, BufferStore) and low-level access(Load, Store and opaque access). |
237 | * The LCA may be a For loop or a Block. |
238 | * \param func The PrimFunc to be detected. |
239 | * \return The Map from buffer to the LCA of all access to it. The lca is function root if the |
240 | * return stmt is NullOpt. |
241 | */ |
242 | TVM_DLL Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func); |
243 | |
244 | /*! |
245 | * \brief Verify if the given TIR is well-formed. The verification includes: |
246 | * - Check if expressions not contain vars that is defined outside the block. |
247 | * \param func The PrimFunc to be verified. |
248 | * \param assert_mode The indicator if it raises an error when the function is not well-formed. |
249 | * \return Whether it is a well-formed TIR function. |
250 | */ |
251 | TVM_DLL bool VerifyWellFormed(const PrimFunc& func, bool assert_mode = true); |
252 | |
253 | /*! |
254 | * \brief Find the entry function of the given IRModule, i.e, functions marked by |
255 | * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc. |
256 | * \param mod The IRModule to find the entry function. |
257 | * \param result_g_var The result GlobalVar of the entry function. |
258 | * \return The entry function. |
259 | */ |
260 | const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var); |
261 | |
262 | /*! |
263 | * \brief Find the "anchor block" of the given module. |
264 | * We define the anchor block to be the block with (1) an init statement and (2) having |
265 | * the biggest flops count. The latter condition is only used when there are multiple blocks |
266 | * with an init statement. |
267 | * For example, if the input module is conv2d + fused spatial blocks, conv2d is the anchor block. |
268 | * The input module may not contain more than one such block. For example, a module having |
269 | * two conv2d is not allowed as an input. |
270 | * However, a module created from winograd convolution has multiple blocks with an init statement |
271 | * (input transform, batched GEMM, and output transform). We use the second condition, the flops |
272 | * count, to determine that the batched GEMM block is the anchor block. |
273 | * \param mod The input TIR module. |
274 | * \return The anchor block if found, nullptr otherwise. |
275 | */ |
276 | const tir::BlockNode* FindAnchorBlock(const IRModule& mod); |
277 | |
278 | // Pass variants of verification analysis |
279 | // directly throws RuntimeError when verification fails. |
280 | namespace transform { |
281 | |
282 | using tvm::transform::Pass; |
283 | using tvm::transform::PassContext; |
284 | |
285 | /*! |
286 | * \brief Pass variant of VerifySSA. |
287 | * |
288 | * \returns The pass. |
289 | * \sa tvm::tir::VerifySSA |
290 | */ |
291 | TVM_DLL Pass VerifySSA(); |
292 | |
293 | /*! |
294 | * \brief Pass variant of VerifyMemory. |
295 | * |
296 | * \returns The pass. |
297 | * \sa tvm::tir::VerifyMemory |
298 | */ |
299 | TVM_DLL Pass VerifyMemory(); |
300 | |
301 | /*! |
302 | * \brief Pass variant of VerifyGPUCode. |
303 | * |
304 | * \param constraints The dict to specify constraints to check. |
305 | * |
306 | * \returns The pass. |
307 | * \sa tvm::tir::VerifyGPUCode |
308 | */ |
309 | TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints); |
310 | |
311 | /*! |
312 | * \brief Pass to checks if the size of the allocated vtcm memory satisfies the limit |
313 | * |
314 | * \param limit The limit to check. |
315 | * |
316 | * \returns The pass. |
317 | * \sa tvm::tir::CalculateAllocatedBytes |
318 | */ |
319 | TVM_DLL Pass VerifyVTCMLimit(const Integer& limit); |
320 | |
321 | /*! |
322 | * \brief Statically check TIR code for out of bounds array access. |
323 | * |
324 | * This analysis is conservative: it will only raise errors if it can prove |
325 | * that out of bounds access occurs. Cases that are uncertain do not raise |
326 | * errors. |
327 | * |
328 | * \returns The pass. |
329 | */ |
330 | TVM_DLL Pass OOBChecker(); |
331 | |
332 | } // namespace transform |
333 | } // namespace tir |
334 | } // namespace tvm |
335 | #endif // TVM_TIR_ANALYSIS_H_ |
336 | |