1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/tir/analysis.h
22 * \brief Analysis utilities and passes for TIR.
23 */
24#ifndef TVM_TIR_ANALYSIS_H_
25#define TVM_TIR_ANALYSIS_H_
26
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>

#include <functional>
#include <string>
35
36namespace tvm {
37namespace tir {
38
/*!
 * \brief Compare two expressions recursively and check whether they are equal
 * to each other without var remapping.
 *
 * This function does not remap variable bindings: it will not
 * return true for (let x = 1 in x + 1) vs (let y = 1 in y + 1), unless x.same_as(y).
 *
 * Use StructuralEqual for such cases.
 *
 * Because it does not remap variables, this function can run
 * faster than StructuralEqual and can be used as a utility function during arithmetic
 * simplifications.
 *
 * \sa StructuralEqual
 */
struct ExprDeepEqual {
 public:
  /*!
   * \brief Check whether two expressions are deeply equal without var remapping.
   * \param lhs The left-hand side expression.
   * \param rhs The right-hand side expression.
   * \return Whether `lhs` and `rhs` are equal.
   */
  TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
};
58
59/*!
60 * \brief Visit the PrimFuncs in the IRModule
61 * \tparam FLambda The type of the PrimFunc visitor
62 * \param mod The IRModule to be visited
63 * \param fvisit The visitor to the PrimFuncs in the IRModule
64 */
65template <class FLambda>
66inline void VisitPrimFuncs(const IRModule& mod, FLambda fvisit) {
67 for (const auto& kv : mod->functions) {
68 const BaseFunc& base_func = kv.second;
69 if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
70 fvisit(prim_func);
71 }
72 }
73}
74
/*!
 * \brief Estimate the FLOPs of a TIR fragment.
 * \param stmt The TIR fragment whose FLOPs are estimated.
 * \return The estimated FLOPs.
 */
TVM_DLL double EstimateTIRFlops(const Stmt& stmt);
81
/*!
 * \brief Estimate the FLOPs of the TIR functions in an IRModule.
 * \param mod The IRModule whose FLOPs are estimated.
 * \return The estimated FLOPs.
 */
TVM_DLL double EstimateTIRFlops(const IRModule& mod);
88
/*!
 * \brief Find undefined vars in the statement.
 * \param stmt The statement to be checked.
 * \param defs The vars that are already defined.
 * \return Array of undefined vars.
 */
TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
96
/*!
 * \brief Find undefined vars in the expression.
 * \param expr The expression to be checked.
 * \return Array of the vars that are used but not defined in `expr`.
 */
TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
103
/*!
 * \brief Analyze the side effect of an expression.
 * \param expr The expression to be checked.
 *
 * \return The CallEffectKind of `expr`; it can be kPure, kReadState or kUpdateState.
 */
TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
111
/*!
 * \brief Whether the given Stmt uses any var in the given variable set.
 * \param stmt The Stmt to be checked.
 * \param vset_contains Predicate that returns true if a var is in the variable set.
 * \return Whether `stmt` uses any var in the given variable set.
 */
TVM_DLL bool UsesVar(const Stmt& stmt, std::function<bool(const VarNode*)> vset_contains);
119
/*!
 * \brief Whether the given PrimExpr uses any var in the given variable set.
 * \param expr The PrimExpr to be checked.
 * \param vset_contains Predicate that returns true if a var is in the variable set.
 * \return Whether `expr` uses any var in the given variable set.
 */
TVM_DLL bool UsesVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
127
/*!
 * \brief Verify whether the given PrimFunc is in SSA form.
 * That is: each Var is defined and assigned once (in Let/For).
 *
 * \param func The function to be verified.
 * \return Whether the IR is in SSA form.
 *
 * \note All passes in TIR consume and produce SSA form.
 */
TVM_DLL bool VerifySSA(const PrimFunc& func);
138
/*!
 * \brief Verify whether memory accesses are legal for a specific target device type.
 *
 * When the target is CUDA and not all of the workload is bound to
 * threads, CPU code is generated that tries to access GPU memory,
 * which is illegal. This pass performs verification for this case.
 *
 * \param func The function to be verified.
 * \return Whether the memory verification succeeded.
 */
TVM_DLL bool VerifyMemory(const PrimFunc& func);
150
/*!
 * \brief Verify the correctness of GPU code.
 * It checks whether the amount of memory usage or the number of threads
 * in a block exceeds the limit.
 * \param func The function to be checked.
 * \param constraints The dict to specify constraints to check.
 * Possible keys are
 *
 * "max_local_memory_per_block": Total amount of local memory per block (in bytes).
 * "max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
 * "max_threads_per_block": Maximum number of threads per block.
 * "max_thread_x": Maximum length of threadIdx.x.
 * "max_thread_y": Maximum length of threadIdx.y.
 * "max_thread_z": Maximum length of threadIdx.z.
 *
 * If a key is missing from this argument, the pass does not check that item.
 * \return Whether it is valid GPU code.
 *
 */
TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);
171
/*!
 * \brief Verify that the VTCM usage of the given PrimFunc is within the provided limit.
 * \param func The function to be checked.
 * \param limit The limit to check against.
 * \return true if the VTCM usage is within the provided limit.
 */
TVM_DLL bool VerifyVTCMLimit(const PrimFunc& func, Integer limit);
179
/*!
 * \brief Auto-detect the block access regions according to its body stmt.
 * It detects the access regions as an array in order of appearance in the AST.
 * \param block The block to be analyzed.
 * \param buffer_var_map The outside buffers which may be accessed by the block.
 * It is a map from buffer var to the buffer.
 * \return Array of access regions.
 * There are three arrays of BufferRegion:
 * - first: read regions
 * - second: write regions
 * - third: opaque regions
 */
TVM_DLL Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
                                                        const Map<Var, Buffer>& buffer_var_map);
194
/*!
 * \brief Auto-detect the block read/write regions according to its body stmt. An opaque access
 * will be counted as both a read and a write access.
 * \param block The block to be analyzed.
 * \param buffer_var_map The outside buffers which may be accessed by the block.
 * It is a map from buffer var to the buffer.
 * \return An array consisting only of the read regions and the write regions of the input block.
 */
// NOTE(review): presumably ordered as [reads, writes], mirroring GetBlockAccessRegion -- confirm.
TVM_DLL Array<Array<BufferRegion>> GetBlockReadWriteRegion(const Block& block,
                                                           const Map<Var, Buffer>& buffer_var_map);
205
/*!
 * \brief Calculate the expression complexity based on the number of symbols it contains.
 * \param expr The expr to be calculated.
 * \return The complexity of `expr`.
 */
TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr);
211
/*!
 * \brief Calculate the size in bytes of the constants needed by the TIR allocates inside the
 * TIR PrimFunc.
 * \param func The TIR PrimFunc for which the constants size is to be calculated.
 * \param constant_byte_alignment The byte alignment required for each constant allocated.
 * \return The constants size in bytes.
 */
TVM_DLL size_t CalculateConstantBytes(const PrimFunc& func, const Integer& constant_byte_alignment);
218
/*!
 * \brief Calculate the workspace size in bytes needed by the TIR allocates inside the TIR PrimFunc.
 * \param func The TIR PrimFunc for which the workspace size is to be calculated.
 * \param workspace_byte_alignment The byte alignment required for each tensor allocated in this
 * workspace.
 * \return The workspace size in bytes.
 */
TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
                                       const Integer& workspace_byte_alignment);
227
/*!
 * \brief Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc.
 * \param func The TIR PrimFunc for which the allocated memory size is to be calculated.
 * \return A map from memory scope to the allocated size in bytes.
 */
TVM_DLL tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func);
233
/*!
 * \brief Detect the lowest common ancestor (LCA) of buffer accesses, including both high-level
 * accesses (BufferLoad, BufferStore) and low-level accesses (Load, Store and opaque access).
 * The LCA may be a For loop or a Block.
 * \param func The PrimFunc to be analyzed.
 * \return The map from each buffer to the LCA of all accesses to it. A NullOpt value means
 * the LCA is the function root.
 */
TVM_DLL Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func);
243
/*!
 * \brief Verify whether the given TIR is well-formed. The verification includes:
 * - Check that expressions do not contain vars that are defined outside the block.
 * \param func The PrimFunc to be verified.
 * \param assert_mode Whether to raise an error when the function is not well-formed.
 * \return Whether it is a well-formed TIR function.
 */
TVM_DLL bool VerifyWellFormed(const PrimFunc& func, bool assert_mode = true);
252
253/*!
254 * \brief Find the entry function of the given IRModule, i.e, functions marked by
255 * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc.
256 * \param mod The IRModule to find the entry function.
257 * \param result_g_var The result GlobalVar of the entry function.
258 * \return The entry function.
259 */
260const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var);
261
262/*!
263 * \brief Find the "anchor block" of the given module.
264 * We define the anchor block to be the block with (1) an init statement and (2) having
265 * the biggest flops count. The latter condition is only used when there are multiple blocks
266 * with an init statement.
267 * For example, if the input module is conv2d + fused spatial blocks, conv2d is the anchor block.
268 * The input module may not contain more than one such block. For example, a module having
269 * two conv2d is not allowed as an input.
270 * However, a module created from winograd convolution has multiple blocks with an init statement
271 * (input transform, batched GEMM, and output transform). We use the second condition, the flops
272 * count, to determine that the batched GEMM block is the anchor block.
273 * \param mod The input TIR module.
274 * \return The anchor block if found, nullptr otherwise.
275 */
276const tir::BlockNode* FindAnchorBlock(const IRModule& mod);
277
// Pass variants of the verification analyses declared in this header.
// These passes directly throw RuntimeError when verification fails.
namespace transform {

using tvm::transform::Pass;
using tvm::transform::PassContext;

/*!
 * \brief Pass variant of VerifySSA.
 *
 * \returns The pass.
 * \sa tvm::tir::VerifySSA
 */
TVM_DLL Pass VerifySSA();

/*!
 * \brief Pass variant of VerifyMemory.
 *
 * \returns The pass.
 * \sa tvm::tir::VerifyMemory
 */
TVM_DLL Pass VerifyMemory();

/*!
 * \brief Pass variant of VerifyGPUCode.
 *
 * \param constraints The dict to specify constraints to check.
 *
 * \returns The pass.
 * \sa tvm::tir::VerifyGPUCode
 */
TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints);

/*!
 * \brief Pass to check whether the size of the allocated VTCM memory satisfies the limit.
 *
 * \param limit The limit to check against.
 *
 * \returns The pass.
 * \sa tvm::tir::VerifyVTCMLimit
 * \sa tvm::tir::CalculateAllocatedBytes
 */
TVM_DLL Pass VerifyVTCMLimit(const Integer& limit);

/*!
 * \brief Statically check TIR code for out-of-bounds array access.
 *
 * This analysis is conservative: it only raises errors if it can prove
 * that an out-of-bounds access occurs. Uncertain cases do not raise
 * errors.
 *
 * \returns The pass.
 */
TVM_DLL Pass OOBChecker();

}  // namespace transform
333} // namespace tir
334} // namespace tvm
335#endif // TVM_TIR_ANALYSIS_H_
336