1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_TIR_SCHEDULE_ANALYSIS_H_
20#define TVM_TIR_SCHEDULE_ANALYSIS_H_
21
22#include <tvm/arith/analyzer.h>
23#include <tvm/ir/op.h>
24#include <tvm/tir/index_map.h>
25#include <tvm/tir/schedule/schedule.h>
26#include <tvm/tir/schedule/state.h>
27
28#include <tuple>
29#include <unordered_map>
30#include <unordered_set>
31#include <utility>
32#include <vector>
33
34#include "../../runtime/thread_storage_scope.h"
35
36namespace tvm {
37namespace tir {
38
/******** Verification ********/
/*!
 * \brief Verifies that the sref tree held by the schedule state is consistent with the IR
 * \param self The schedule state containing the sref tree to be verified
 * \throw An exception will be thrown if the sref tree is not valid
 */
void VerifySRefTree(const ScheduleState& self);
/*!
 * \brief Verifies the cached flags in the schedule state, including:
 * - affine_binding
 * - region_cover
 * - stage_pipeline
 * \param self The schedule state to be verified
 * \throw An exception will be thrown if some srefs are not valid
 */
void VerifyCachedFlags(const ScheduleState& self);
55
/******** IR Module ********/
/*!
 * \brief Get the PrimFunc and the GlobalVar that the root block belongs to
 * \param mod The IRModule to search
 * \param root_block The root block of the PrimFunc
 * \param result_g_var Output parameter that receives the GlobalVar of the found PrimFunc
 * \return The PrimFunc that the root block belongs to
 * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write
 */
const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
                                    GlobalVar* result_g_var);

/*!
 * \brief Get the root node of the sref tree, which is the root block of the PrimFunc.
 * \param sref The given sref.
 * \return The root node of the sref tree which contains the given node.
 */
StmtSRef GetSRefTreeRoot(const StmtSRef& sref);
74
/******** Scope ********/
/*!
 * \brief Get the scope root of the given sref, optionally checking that the scope is a
 * stage pipeline
 * \param self The schedule state
 * \param sref The sref whose scope root is to be retrieved
 * \param require_stage_pipeline Whether to additionally require the scope to be a stage pipeline
 * \throw ScheduleError if
 * 1) the sref has been the root of the AST (so it has no scope root), or
 * 2) require_stage_pipeline = true, but its scope root is not a stage pipeline
 * \return The block sref to the scope root
 */
StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline);
87
88/*!
89 * \brief The information of a block scope, including the leaf blocks,
90 * as well as the loop types (spatial, reduction) for each loop in the scope.
91 */
92struct ScopeBlockLoopInfo {
93 /*! \brief A list of the leaf blocks, from left to right */
94 std::vector<BlockRealize> realizes;
95 /*! \brief The loop vars bound to spatial block iters */
96 std::unordered_set<const VarNode*> spatial_vars;
97 /*! \brief The loop vars bound to non-spatial block iters */
98 std::unordered_set<const VarNode*> non_spatial_vars;
99};
100
101/*!
102 * \brief Inspect the scope of the given sref
103 * \param scope_block The root block of the scope
104 * \return The information of the scope
105 */
106ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block);
107
108/*!
109 * \brief Checks whether the block is a complete block under the scope
110 * \param self The schedule state
111 * \param block_sref The block to be checked
112 * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in
113 * \return A boolean indicating if the block is a complete block
114 * \note Definition of a complete block:
115 * 1) All block vars are data parallel
116 * 2) Dominant: the block is the only writer of its output,
117 * dominating the reader of its output buffers
118 * 3) No overlap between the buffers the block reads and writes
119 */
120bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
121 const StmtSRef& scope_root_sref);
122
123/*!
124 * \brief Check if the block is a complete block under the scope
125 * \param self The schedule state
126 * \param block_sref The sref to the block whose completeness is to be checked
127 * \param scope_root_sref The scope root of the block
128 * \throw ScheduleError If the block is not a complete block
129 */
130void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
131 const StmtSRef& scope_root_sref);
132
133/*!
134 * \brief Check whether the block is a reduction block under the scope
135 * \param self The schedule state
136 * \param block_sref The block to be checked
137 * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in
138 * \return A boolean indicating if the block is a reduction block
139 * \note Definition of a reduction block:
140 * 1) The block has the `init` statement
141 * 2) All the block bindings are quasi-affine expressions
142 * 3) All block vars are either data parallel block vars or reduction block vars
143 * 4) Dominant: the block is the only writer of its output, dominating the reader of its output
144 * buffers
145 * 5) The reduction block vars are not used to index the output buffers
146 */
147bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
148 const StmtSRef& scope_root_sref);
149
150/*!
151 * \brief Check if the block is a reduction block under the scope
152 * \param self The schedule state
153 * \param block_sref The sref of the block to be checked
154 * \param scope_root_sref The scope root of the block
155 * \throw ScheduleError If the block is not a reduction block
156 */
157void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
158 const StmtSRef& scope_root_sref);
159
160/*!
161 * \brief Check if the block is a complete block or a reduction block under the scope
162 * \param self The schedule state
163 * \param block_sref The sref of the block to be checked
164 * \param scope_root_sref The scope root of the block
165 * \throw ScheduleError If the block is neither a complete block nor a reduction block
166 */
167void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
168 const StmtSRef& scope_root_sref);
169
170/*!
171 * \brief Check the subtree compact dataflow property. The scope root may have one or more subtrees
172 * rooted at its direct children, and this property requires all the blocks of the subtree
173 * that the specified sref is in to be local complete block or local reduction block.
174 * \param self The schedule state
175 * \param subtree_root The sref of the subtree root to be checked
176 */
177void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root);
178/*!
179 * \brief Check if the block is an output block, i.e. the block writes to at least a buffer that is
180 * not allocated under the current scope
181 * \param self The schedule state
182 * \param block_sref The block to be checked
183 * \param scope_root_sref The scope root of the block
184 * \return A boolean flag indicating if the block is an output block
185 */
186bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref,
187 const StmtSRef& scope_root_sref);
188
189/*!
190 * \brief Check if the block is not an output block, i.e. all the buffers the block writes to
191 * are allocated under the current scope
192 * \param self The schedule state
193 * \param block_sref The block to be checked
194 * \param scope_root_sref The scope root of the block
195 * \throw ScheduleError if the block is an output block
196 */
197void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref,
198 const StmtSRef& scope_root_sref);
199
200/*!
201 * \brief Extracts the types of the block vars
202 * \param block_sref The block to be checked
203 * \return A vector of types of the block vars
204 */
205std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref);
206
207/*!
208 * \brief Checks if a block could be considered as a "write cache"
209 * \param block_sref The block to be checked
210 * \return A boolean flag indicating if the block is a write cache
211 */
212bool IsWriteCache(const StmtSRef& block_sref);
213
/******** Binding ********/
/*!
 * \brief Verifies if the block binding in a specific BlockRealize is an affine binding.
 * The binding can be represented as an injective affine map from the loop iterators.
 * \param realize The BlockRealize to be analyzed
 * \param loop_var_ranges The ranges of the loop variables
 * \param analyzer The analyzer
 * \return A boolean flag indicating if the binding is affine
 */
bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_var_ranges,
                     arith::Analyzer* analyzer);

/*!
 * \brief Check whether a block has an affine binding using the cached flag, and throw an exception
 * if the block does not have an affine binding.
 * \param self The schedule state
 * \param block The block to be checked
 * \throw ScheduleError If the input block does not have an affine binding
 */
void CheckAffineBinding(const ScheduleState& self, Block block);

/*!
 * \brief Check whether a block has an affine binding under the `high_exclusive` sref node, and
 * throw an exception if the block does not have an affine binding.
 * \param self The schedule state
 * \param block The block to be checked
 * \param high_exclusive The highest sref node, exclusive
 * \throw ScheduleError If the input block does not have an affine binding
 */
void CheckPartialAffineBinding(const ScheduleState& self, Block block,
                               const Optional<StmtSRef>& high_exclusive);
245
246/*!
247 * \brief Extracts the ranges of loop variables in a path of the sref tree
248 * \param low_inclusive The lowest node in the path
249 * \param high_exclusive The highest node in the path, defaults to the scope root if not specified
250 * \param extra_relax_scope If the scope is not global, the method will look beyond the limit and
251 * retrieve extra domains. For example,
252 * - if the storage scope is warp, it will look upwards for threadIdx.x
253 * - if the storage scope is shared, it will look for threadIdx.x/y/z
254 * \return The loop domain
255 */
256Map<Var, Range> LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive,
257 const Optional<StmtSRef>& high_exclusive = NullOpt,
258 const runtime::StorageScope& extra_relax_scope = //
259 runtime::StorageScope{runtime::StorageRank::kGlobal, ""});
260
261/*!
262 * \brief Returns the block var binding
263 * \param realize The BlockRealize to be analyzed
264 * \return The block var binding
265 */
266Map<Var, PrimExpr> GetBindings(const BlockRealize& realize);
267
268/*!
269 * \brief Get the vars involved in the bindings of data parallel block vars and reduction block
270 * vars, respectively
271 * \param block_realize The BlockRealize to be analyzed
272 * \param data_par_vars The vars that appear in the binding of any data parallel block iter
273 * \param reduce_vars The vars that appear in the binding of any reduction block iter
274 * \return A boolean indicating whether the block has block iters that is neither a data parallel
275 * block iter nor a reduction block iter
276 */
277bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize,
278 std::unordered_set<const VarNode*>* data_par_vars,
279 std::unordered_set<const VarNode*>* reduce_vars);
280
/******** Loop properties ********/
/*!
 * \brief Check that the loop starts with zero.
 * \param self The schedule state
 * \param loop_sref The StmtSRef that points to the loop to be checked
 * \param analyzer The arithmetic analyzer
 * \throw ScheduleError If the loop doesn't start with zero.
 */
void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref,
                             arith::Analyzer* analyzer);

/*!
 * \brief Check whether a block has a trivial binding, i.e. each block var is bound to an outer
 * loop, from outer to inner.
 * \param self The schedule state
 * \param block_sref The block to be checked
 * \throw ScheduleError If the block does not have trivial bindings
 */
void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref);
300
/******** Block-loop relation ********/

/*!
 * \brief Gets the StmtSRefs of the leaf blocks under a specific block/loop on the sref tree
 * \param self The schedule state
 * \param parent_sref The StmtSRef that points to the parent block/loop
 * \return A list of StmtSRefs of the leaf blocks
 */
Array<StmtSRef> GetChildBlockSRefOnSRefTree(const ScheduleState& self, const StmtSRef& parent_sref);

/*!
 * \brief Gets the BlockRealizes of the leaf blocks under a specific block/loop on the sref tree
 * \param parent_sref The StmtSRef that points to the parent block/loop
 * \return A list of leaf BlockRealizes
 */
Array<BlockRealize> GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref);

/*!
 * \brief Get the BlockRealize of the single child block of the block or loop specified by
 * `parent_sref` on SRef tree, or throw an exception if there are 0 or multiple child blocks
 * \param self The schedule state
 * \param parent_sref The StmtSRef that points to the parent block/loop
 * \return The BlockRealize of the single child block
 * \throw ScheduleError If there are 0 or multiple child blocks
 */
BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self,
                                                       const StmtSRef& parent_sref);

/*!
 * \brief Get the BlockRealize of the input block
 * \param self The schedule state
 * \param block_sref The StmtSRef of the queried block
 * \return The BlockRealize of the input block
 */
BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref);
336
337/*!
338 * \brief Get the IterVarType of the specific loop, according to the blocks it's bound to
339 * \param loop_sref The loop to be checked
340 * \return The IterVarType of the specific loop
341 */
342IterVarType GetLoopIterType(const StmtSRef& loop_sref);
343
344/*!
345 * \brief Get the lowest common ancestor of an array of blocks or loops on the sref tree
346 * \param srefs The block srefs or loop srefs whose lowest common ancestor is to be queried
347 * \return The lowest common ancestor of the input block srefs or loop srefs
348 * \note The input array is required to have at least one sref
349 */
350StmtSRef GetSRefLowestCommonAncestor(const Array<StmtSRef>& srefs);
351
352/*!
353 * \brief Checks if the given block has been applied by multi-level tiling. We check this by
354 * examine the block's annotation.
355 * \param block_sref The block to be checked
356 * \return A boolean indicating whether the block has been multi-level tiled.
357 */
358bool HasBeenMultiLevelTiled(const StmtSRef& block_sref);
359
360/*!
361 * \brief Collect all the feasible compute-at locations of the input block
362 * \param self The schedule state
363 * \param block_sref The block whose compute-at locations are to be collected
364 * \return All the feasible compute-at locations of the input block, given as an array of loop srefs
365 * and an array of their indices among the outer loops of the input block
366 */
367std::pair<Array<StmtSRef>, std::vector<int>> CollectComputeLocation(const ScheduleState& self,
368 const StmtSRef& block_sref);
369
/******** Producer-consumer relation ********/

/*!
 * \brief Get the producer blocks of the given block under the given scope
 * \param block_sref The block whose producers are to be retrieved
 * \param scope The block scope where the given block is in
 * \return The producer blocks of the specified block
 */
Array<StmtSRef> GetProducers(const StmtSRef& block_sref, const BlockScope& scope);

/*!
 * \brief Get the consumer blocks of the given block under the given scope
 * \param block_sref The block whose consumers are to be retrieved
 * \param scope The block scope where the given block is in
 * \return The consumer blocks of the specified block
 */
Array<StmtSRef> GetConsumers(const StmtSRef& block_sref, const BlockScope& scope);
387
388/*!
389 * \brief A solution to split a ordered list of subtrees into two parts,
390 * where producers are on the LHS and consumers are on the RHS.
391 * For example, subtree[0, 3) are on the LHS, and subtree[3, 6) are on the RHS.
392 */
393struct ProducerConsumerSplit {
394 /*! \brief Indicates that all producers fall into `subtrees[0, last_producer_position]` */
395 int last_producer_position;
396 /*! \brief Indicates that all consumers fall into `subtrees[first_consumer_position, ...)` */
397 int first_consumer_position;
398 /*! \brief The number of given producers visited in `subtrees` */
399 int n_producers_visited;
400 /*! \brief The number of given consumers visited in `subtrees` */
401 int n_consumers_visited;
402 /*!
403 * \brief Find a split among the given `subtree`
404 * \param state The schedule state
405 * \param subtrees The ordered list of subtrees to be split
406 * \param producer_block_srefs The producers
407 * \param consumer_block_srefs The consumers
408 * \param block2realize If not null, the corresponding BlockRealize to each block in the scope
409 * will be saved in this map
410 * \return The valid split points are (last_producer_position, first_consumer_position]
411 * \throw ScheduleError is not valid split is found
412 */
413 static ProducerConsumerSplit Find(
414 const ScheduleState& state, const Array<Stmt>& subtrees,
415 const Array<StmtSRef>& producer_block_srefs, const Array<StmtSRef>& consumer_block_srefs,
416 std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize);
417};
418
/******** Block-buffer relation ********/

/*!
 * \brief Get the n-th read or write buffer of the given block.
 * \param self The schedule state.
 * \param block The queried block.
 * \param n The index of the queried buffer.
 * \param index_type The type of the buffer index, kRead or kWrite.
 * \return The buffer of the n-th read/write region of the block.
 * \throw ScheduleError If the buffer index is out of bound.
 */
Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n,
                          BufferIndexType index_type);

/*!
 * \brief Get the n-th read or write buffer region of the given block.
 * \param self The schedule state.
 * \param block The queried block.
 * \param n The index of the queried buffer.
 * \param index_type The type of the buffer index, kRead or kWrite.
 * \return The n-th read/write region of the block.
 * \throw ScheduleError If the buffer index is out of bound.
 */
BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& block, int n,
                                      BufferIndexType index_type);

/*!
 * \brief Find the defining site of the buffer in the given block and its ancestors
 * \param block_sref The block sref
 * \param buffer The buffer
 * \return A pair of (1) the defining site of the buffer and (2) a flag indicating whether the
 * buffer is allocated (otherwise the buffer is from match_buffer).
 * NOTE(review): the site is Optional; presumably it is NullOpt when no defining site is found
 * among the ancestors - confirm against the implementation.
 */
std::pair<Optional<StmtSRef>, bool> GetBufferDefiningSite(const StmtSRef& block_sref,
                                                          const Buffer& buffer);
454
/******** Reduction Block Related ********/

/*!
 * \brief Get the init values and the BufferStore updates from the input reduction block
 * \param self The schedule state, used for error reporting
 * \param block The block from which the init values and BufferStore updates are extracted
 * \return The extracted init values and BufferStore updates
 * \throw ScheduleError If rfactor or cross-thread reduction cannot be applied to the block
 */
std::pair<Array<PrimExpr>, Array<BufferStore>> GetInitValuesAndUpdatesFromReductionBlock(
    const Optional<ScheduleState>& self, Block block);

/*!
 * \brief Check whether the input array of IterVars only contains data-parallel and reduction block
 * iters
 * \param iters The input array of IterVars to be checked
 * \return A boolean indicating whether the input array of IterVars only contains data-parallel and
 * reduction block iters
 */
bool ContainsOnlyDataParAndReductionBlockIter(const Array<IterVar>& iters);

/*!
 * \brief Check whether the block's reduction block iters are not used to index the block's output
 * buffers
 * \param block The block to be checked
 * \return A boolean indicating whether the block's reduction block iters are not used to index the
 * block's output buffers
 */
bool ReductionIterNotIndexOutputBuffer(const Block& block);

/*!
 * \brief Given a list of reduction identities and a list of reduction combiners, detect the
 * corresponding commutative reducer, and extract the combiner LHS values and combiner RHS values
 * \param self The schedule state (optional; NOTE(review): presumably only used for error
 * reporting, as in GetInitValuesAndUpdatesFromReductionBlock - confirm)
 * \param identities The reduction identities to be analyzed
 * \param combiners The reduction combiners to be analyzed
 * \return The corresponding CommReducer, combiner LHS values and combiner RHS values
 * \throw ScheduleError If no corresponding commutative reducer can be matched
 */
std::tuple<CommReducer, Array<PrimExpr>, Array<PrimExpr>> GetReducerAndCombinerLhsRhs(
    const Optional<ScheduleState>& self, const Array<PrimExpr>& identities,
    const Array<BufferStore>& combiners);
497
/******** Commutative Reducer ********/

/*!
 * \brief Get the list of the registered reducer-getter functions
 * \return The list of the registered reducer-getter functions
 * \sa ReducerRegistry
 */
std::vector<runtime::TypedPackedFunc<Optional<CommReducer>(Array<PrimExpr>)>> GetReducerGetters();

/*!
 * \brief Given the input identities and the combiner BufferStores of a reduction, extract the
 * corresponding commutative reducer, LHS values and RHS values, if possible.
 * \param identities The identities of the reduction
 * \param combiners The combiners of the reduction
 * \param result_reducer Output parameter receiving the extracted CommReducer
 * \param lhs Output parameter receiving the extracted LHS values of the reducer
 * \param rhs Output parameter receiving the extracted RHS values of the reducer
 * \return A boolean indicating whether a corresponding commutative reducer is found
 */
bool FromIdentityCombiner(const Array<PrimExpr>& identities, const Array<BufferStore>& combiners,
                          CommReducer* result_reducer, Array<PrimExpr>* lhs, Array<PrimExpr>* rhs);
519
/******** Misc ********/

/*!
 * \brief Check whether the input storage scope string is valid. Throw an error if not.
 * \param self The schedule state
 * \param storage_scope The storage scope string to be checked
 * \throw ScheduleError If the input storage scope is not valid
 */
void CheckStorageScope(const ScheduleState& self, String storage_scope);

/*!
 * \brief Checks if a block could be successfully computed inline into its consumer
 * \param self The schedule state
 * \param block_sref The block to be checked
 * \return A boolean indicating whether the block could be successfully computed inline
 */
bool CanComputeInline(const ScheduleState& self, const StmtSRef& block_sref);

/*!
 * \brief Checks if a block could be successfully computed inline into its producer
 * \param self The schedule state
 * \param block_sref The block to be checked
 * \return A boolean indicating whether the block could be successfully reverse computed inline
 */
bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref);

/*!
 * \brief Checks if a producer block could be successfully computed at the specific loop.
 * \param self The schedule state
 * \param block_sref The block to be moved
 * \param loop_sref The loop the block is to be moved under
 * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
 * \return A boolean indicating whether the block could be successfully computed at the specific
 * loop
 */
bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
                  bool preserve_unit_loops);

/*!
 * \brief Checks if a consumer block could be successfully computed at the specific loop.
 * \param self The schedule state
 * \param block_sref The block to be moved
 * \param loop_sref The loop the block is to be moved under
 * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
 * \return A boolean indicating whether the block could be successfully reverse computed at the
 * specific loop
 */
bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
                         const StmtSRef& loop_sref, bool preserve_unit_loops);
568
569/*!
570 * \brief Provided the access pattern to a buffer, suggest one of the possible layout
571 * transformation to minimize the locality of the access pattern.
572 * \param buffer The buffer to be transformed
573 * \param indices The access pattern to the buffer
574 * \param loops The loops above the buffer
575 * \param predicate The predicate of the access
576 * \param analyzer Arithmetic analyzer
577 */
578Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>& indices,
579 const Array<For>& loops, const PrimExpr& predicate,
580 arith::Analyzer* analyzer);
581
582/*!
583 * \brief Checks if the given AST contains the specific operators
584 * \param stmt The AST statement to be checked
585 * \param ops The list of operators to be checked
586 * \return A boolean indicating whether the AST contains the specific operators
587 */
588bool HasOp(const Stmt& stmt, const Array<Op>& ops);
589
590/*!
591 * \brief Checks if the given AST statement contains if-then-else, including
592 * 1) IfThenElse statement
593 * 2) Select expression
594 * 3) The operator `tir.if_then_else`
595 * 4) non-constant-true Block predicates
596 * \param stmt The AST statement to be checked
597 * \return A boolean indicating whether the statement contains the if-then-else pattern
598 */
599bool HasIfThenElse(const Stmt& stmt);
600
601/*!
602 * \brief Given the read/write region, extract the pattern of their index correspondence
603 * namely, the mapping from read index to the write index.
604 * \param read_region The read region
605 * \param write_region The write region
606 * \return A tuple of booleans, the extracted pattern
607 * 0) exists: if the pattern is found
608 * 1) surjective: if the pattern is surjective, i.e. each write index is mapped at least once
609 * e.g. A[i, j] = B[i, i, j]
610 * 2) injective: if the pattern is injective, i.e. each write index is mapped at most once.
611 * e.g. A[i, j] = B[i]
612 * 3) ordered: if the mapping is ordered
613 * 4) no_const_read: if there is no constant indexing in the read indices,
614 * e.g. A[i, j] = B[0, i, j]
615 * 5) no_shift_read: if there is no constant shift in the read indices,
616 * e.g. A[i, j] = B[i + 1, j]
617 */
618std::tuple</*exists=*/bool,
619 /*surjective=*/bool,
620 /*injective=*/bool,
621 /*ordered=*/bool,
622 /*no_const_read=*/bool,
623 /*no_shift_read=*/bool>
624AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region);
625
626/*!
627 * \brief Check if the block is a data parallel block, i.e. all the block vars are data parallel
628 * \param block_sref The block to be checked
629 * \return A boolean flag indicating if the block is a data parallel block
630 */
631bool IsSpatial(const StmtSRef& block_sref);
632
633/*!
634 * \brief Check whether a block has a trivial binding, i.e. each block var is bound to a outer loop,
635 * from outer to inner.
636 * \param self The schedule state
637 * \param block_sref The block to be checked
638 * \return A boolean flag indicating if the block has a trivial binding
639 */
640bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref);
641
642/*!
643 * \brief Checks if the given block has data reuse opportunity and thus multi-level tiling is
644 * beneficial.
645 * \param self The schedule state
646 * \param block_sref The block to be checked
647 * \return A boolean indicating whether the block has data reuse opportunity
648 */
649bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref);
650
651/*!
652 * \brief Checks if all the blocks in the PrimFunc is spatial
653 * \param func The PrimFunc to be checked
654 * \return A boolean indicating whether all the blocks in the PrimFunc is spatial
655 */
656bool IsSpatialPrimFunc(const PrimFunc& func);
657
658/*!
659 * \brief Checks if the rfactor or cross thread reduction is beneficial to the given block.
660 * \param self The schedule state.
661 * \param block_sref The block to be checked.
662 * \param max_parallel_extent The maximum parallel jobs on the target.
663 * \param max_parallel_basic The maximum cores on the target.
664 * \return A boolean indicating whether the operation is beneficial.
665 */
666bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
667 const tir::StmtSRef& block_sref, //
668 int64_t max_parallel_extent, //
669 int64_t max_parallel_basic);
670
671/*!
672 * \brief Analyze the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive)
673 * Relaxation of the region may be used in upper-bound analysis, i.e. some extra region may be added
674 * to the result.
675 * \param region The buffer region to be analyzed
676 * \param dom_low_inclusive The lowest node in the sref tree path
677 * \param dom_high_exclusive The highest node in the sref tree path
678 * \return An n-dimensional integer set
679 */
680Array<arith::IntSet> AnalyzeRegionUpperBound(const BufferRegion& region, const PrimExpr& predicate,
681 const StmtSRef& dom_low_inclusive,
682 const StmtSRef& dom_high_exclusive,
683 arith::Analyzer* analyzer);
684
685/*!
686 * \brief Analyze the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive)
687 * Some subregion may be discarded during the lower-bound analysis.
688 * \param realize The block realize that touches the buffer region
689 * \param region The buffer region to be analyzed
690 * \param dom_low_inclusive The lowest node in the sref tree path
691 * \param dom_high_exclusive The highest node in the sref tree path
692 * \param analyzer The analyzer
693 * \return An n-dimensional integer set
694 */
695Array<arith::IntSet> AnalyzeRegionLowerBound(const BufferRegion& region, const PrimExpr& predicate,
696 const StmtSRef& dom_low_inclusive,
697 const StmtSRef& dom_high_exclusive,
698 arith::Analyzer* analyzer);
699
700/*!
701 * \brief Check if buffer indices are all Vars and extr
702 * \param buffer_access The BufferLoad or BufferStore
703 * \return The indices if the indices are all Vars, otherwise NullOpt
704 */
705template <typename T>
706Optional<Array<Var>> CheckTrivialBufferIndices(const T& buffer_access) {
707 Array<Var> indices;
708 for (const PrimExpr& index : buffer_access->indices) {
709 const VarNode* var = index.as<VarNode>();
710 if (var == nullptr) {
711 return NullOpt;
712 }
713 indices.push_back(GetRef<Var>(var));
714 }
715 return indices;
716}
717
718/*!
719 * \brief Simplify non-trivial expressions
720 * \param expr The expression to be simplified
721 * \param analyzer The analyzer
722 * \return The simplified expression
723 *
724 * During scheduling, we often need preserve block iters in trivial expressions that can be
725 * simplified to constant values for further scheduling and analysis because simplifing away the
726 * block iters may result in loss of information for further analysis.
727 */
728PrimExpr SimplifyNonTrivialExpr(const PrimExpr& expr, arith::Analyzer* analyzer);
729
/*! \brief Necessary information used for tensorization */
class TensorizeInfoNode : public Object {
 public:
  /*! \brief Maps loops in a target block to the ones in an intrinsic description */
  Map<tir::StmtSRef, tir::For> loop_map;
  /*! \brief Maps loops in an intrinsic description to their index, ordered outer to inner */
  Map<tir::For, Integer> desc_loop_indexer;
  /*!
   * \brief Optional padded extents of the block iters, present when padding is needed to match
   * the intrinsic description
   */
  Optional<Array<Integer>> block_iter_paddings;

  /*! \brief Visits each field above, by name, with the given AttrVisitor */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("loop_map", &loop_map);
    v->Visit("desc_loop_indexer", &desc_loop_indexer);
    v->Visit("block_iter_paddings", &block_iter_paddings);
  }

  /*! \brief Type key registered for this node in the TVM object system */
  static constexpr const char* _type_key = "tir.schedule.TensorizeInfo";
  TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object);
};
751
/*!
 * \brief Managed, non-nullable reference to TensorizeInfoNode
 * \sa TensorizeInfoNode
 */
class TensorizeInfo : public ObjectRef {
 public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode);
};
756
757/*!
758 * \brief Establish a mapping between loops in a target block and an intrinsic description
759 * \param self The schedule state to be tensorized
760 * \param block_sref The target block to match against
761 * \param desc_func The prim func describing the computation to be tensorized
762 * \param allow_padding Whether to allow padding the block iters to match the intrinsic description
763 * \return TensorizeInfo structure if a valid mapping is found, NullOpt otherwise
764 */
765Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
766 const tir::StmtSRef& block_sref,
767 const tir::PrimFunc& desc_func, bool allow_padding);
768
/*! \brief Necessary information used to perform transformations for tensorization */
class AutoTensorizeMappingInfoNode : public Object {
 public:
  /*! \brief Possible mappings to apply to block iters */
  Array<IndexMap> mappings;

  /* Additional information from AutoTensorizeComparator */

  /*! \brief Mapping from LHS buffer to RHS buffer */
  Map<Buffer, Buffer> lhs_buffer_map;
  /*! \brief Buffer indices on RHS */
  Map<Buffer, Array<PrimExpr>> rhs_buffer_indices;
  /*! \brief Block iters on LHS */
  Array<IterVar> lhs_iters;
  /*! \brief Block iters on RHS */
  Array<IterVar> rhs_iters;

  /*! \brief Visits each field above, by name, with the given AttrVisitor */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("mappings", &mappings);
    v->Visit("lhs_buffer_map", &lhs_buffer_map);
    v->Visit("rhs_buffer_indices", &rhs_buffer_indices);
    v->Visit("lhs_iters", &lhs_iters);
    v->Visit("rhs_iters", &rhs_iters);
  }

  /*! \brief Type key registered for this node in the TVM object system */
  static constexpr const char* _type_key = "tir.schedule.AutoTensorizeMappingInfo";
  TVM_DECLARE_FINAL_OBJECT_INFO(AutoTensorizeMappingInfoNode, Object);
};
797
/*!
 * \brief Managed, non-nullable reference to AutoTensorizeMappingInfoNode
 * \sa AutoTensorizeMappingInfoNode
 */
class AutoTensorizeMappingInfo : public ObjectRef {
 public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AutoTensorizeMappingInfo, ObjectRef,
                                            AutoTensorizeMappingInfoNode);
};
803
804/*!
805 * \brief Get mapping info between a target block and an intrinsic description including layout
806 * transformations to apply.
807 * \param self The schedule state
808 * \param block_sref The compute block for auto tensorization
809 * \param desc_func The prim func describing the computation to be tensorized
810 * \return AutoTensorizeMappingInfo structure if a potential mapping is found, NullOpt otherwise.
811 * \note Returning a valid AutoTensorizeMappingInfo doesn't guarantee the block can be tensorized.
812 * We will need to apply the suggested layout transformations and then match against the tensor
813 * intrinsics.
814 */
815Optional<AutoTensorizeMappingInfo> GetAutoTensorizeMappingInfo(const ScheduleState& self,
816 const StmtSRef& block_sref,
817 const PrimFunc& desc_func);
818
819/*!
820 * \brief Perform basic checks for auto tensorization applicability, such as the structure of
821 * arithmetic operations and data types.
822 * \param sch The schedule to be tensorized
823 * \param block_rv The compute block for auto tensorization
824 * \param desc_func The prim func describing the computation to be tensorized
825 * \return true if basic conditions are met.
826 */
827bool CheckAutoTensorizeApplicable(const tir::Schedule& sch, const tir::BlockRV& block_rv,
828 const tir::PrimFunc& desc_func);
829} // namespace tir
830} // namespace tvm
831
832#endif // TVM_TIR_SCHEDULE_ANALYSIS_H_
833