1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_TIR_SCHEDULE_ANALYSIS_H_ |
20 | #define TVM_TIR_SCHEDULE_ANALYSIS_H_ |
21 | |
22 | #include <tvm/arith/analyzer.h> |
23 | #include <tvm/ir/op.h> |
24 | #include <tvm/tir/index_map.h> |
25 | #include <tvm/tir/schedule/schedule.h> |
26 | #include <tvm/tir/schedule/state.h> |
27 | |
28 | #include <tuple> |
29 | #include <unordered_map> |
30 | #include <unordered_set> |
31 | #include <utility> |
32 | #include <vector> |
33 | |
34 | #include "../../runtime/thread_storage_scope.h" |
35 | |
36 | namespace tvm { |
37 | namespace tir { |
38 | |
39 | /******** Verification ********/ |
40 | /*! |
41 | * \brief Verifies that the sref tree of the schedule state is consistent with the IR |
42 | * \param self The schedule state containing the sref tree to be verified |
43 | * \throw An exception will be thrown if the sref tree is not valid |
44 | */ |
45 | void VerifySRefTree(const ScheduleState& self); |
46 | /*! |
47 | * \brief Verifies the flags cached in the schedule state, including: |
48 | * - affine_binding |
49 | * - region_cover |
50 | * - stage_pipeline |
51 | * \param self The schedule state whose cached flags are to be verified |
52 | * \throw An exception will be thrown if some srefs are not valid |
53 | */ |
54 | void VerifyCachedFlags(const ScheduleState& self); |
55 | |
56 | /******** IR Module ********/ |
57 | /*! |
58 | * \brief Get the PrimFunc and the GlobalVar that the root block belongs to |
59 | * \param mod The IRModule |
60 | * \param root_block The root block of the PrimFunc |
61 | * \param result_g_var Output parameter that receives the resulting GlobalVar |
62 | * \return The PrimFunc that the root block belongs to |
63 | * \note This function returns a raw pointer instead of an ObjectRef to avoid later copy-on-write |
64 | */ |
65 | const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block, |
66 | GlobalVar* result_g_var); |
67 | |
68 | /*! |
69 | * \brief Get the root node of the sref tree, which is the root block of the PrimFunc. |
70 | * \param sref The given sref. |
71 | * \return The root node of the sref tree that contains the given sref. |
72 | */ |
73 | StmtSRef GetSRefTreeRoot(const StmtSRef& sref); |
74 | |
75 | /******** Scope ********/ |
76 | /*! |
77 | * \brief Get the scope root of the specified sref, optionally checking it is a stage pipeline |
78 | * \param self The schedule state |
79 | * \param sref The sref whose scope is to be checked |
80 | * \param require_stage_pipeline A boolean indicating whether to check stage pipeline |
81 | * \throw ScheduleError if |
82 | * 1) the sref is already the root of the AST (so it has no scope root), or |
83 | * 2) require_stage_pipeline = true, but its scope root is not a stage pipeline |
84 | * \return The block sref to the scope root |
85 | */ |
86 | StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline); |
87 | |
88 | /*! |
89 | * \brief The information of a block scope, including the leaf blocks, |
90 | * as well as the loop types (spatial or non-spatial) of the loops in the scope. |
91 | */ |
92 | struct ScopeBlockLoopInfo { |
93 | /*! \brief The BlockRealizes of the leaf blocks, from left to right */ |
94 | std::vector<BlockRealize> realizes; |
95 | /*! \brief The loop vars that are bound to spatial block iters */ |
96 | std::unordered_set<const VarNode*> spatial_vars; |
97 | /*! \brief The loop vars that are bound to non-spatial block iters */ |
98 | std::unordered_set<const VarNode*> non_spatial_vars; |
99 | }; |
100 | |
101 | /*! |
102 | * \brief Inspect the scope rooted at the given block |
103 | * \param scope_block The root block of the scope |
104 | * \return The information of the scope, see ScopeBlockLoopInfo |
105 | */ |
106 | ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block); |
107 | |
108 | /*! |
109 | * \brief Checks whether the block is a complete block under the scope |
110 | * \param self The schedule state |
111 | * \param block_sref The sref of the block to be checked |
112 | * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in |
113 | * \return A boolean indicating if the block is a complete block |
114 | * \note Definition of a complete block: |
115 | * 1) All block vars are data parallel |
116 | * 2) Dominant: the block is the only writer of its output, |
117 | * dominating the readers of its output buffers |
118 | * 3) No overlap between the buffers the block reads and writes |
119 | */ |
120 | bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, |
121 | const StmtSRef& scope_root_sref); |
122 | |
123 | /*! |
124 | * \brief Check that the block is a complete block under the scope, and throw otherwise |
125 | * \param self The schedule state |
126 | * \param block_sref The sref to the block whose completeness is to be checked |
127 | * \param scope_root_sref The sref to the scope root of the block |
128 | * \throw ScheduleError If the block is not a complete block |
129 | */ |
130 | void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, |
131 | const StmtSRef& scope_root_sref); |
132 | |
133 | /*! |
134 | * \brief Checks whether the block is a reduction block under the scope |
135 | * \param self The schedule state |
136 | * \param block_sref The sref of the block to be checked |
137 | * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in |
138 | * \return A boolean indicating if the block is a reduction block |
139 | * \note Definition of a reduction block: |
140 | * 1) The block has an `init` statement |
141 | * 2) All the block bindings are quasi-affine expressions |
142 | * 3) All block vars are either data parallel block vars or reduction block vars |
143 | * 4) Dominant: the block is the only writer of its output, dominating the readers of its output |
144 | * buffers |
145 | * 5) The reduction block vars are not used to index the output buffers |
146 | */ |
147 | bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, |
148 | const StmtSRef& scope_root_sref); |
149 | |
150 | /*! |
151 | * \brief Check that the block is a reduction block under the scope, and throw otherwise |
152 | * \param self The schedule state |
153 | * \param block_sref The sref of the block to be checked |
154 | * \param scope_root_sref The sref to the scope root of the block |
155 | * \throw ScheduleError If the block is not a reduction block |
156 | */ |
157 | void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, |
158 | const StmtSRef& scope_root_sref); |
159 | |
160 | /*! |
161 | * \brief Check that the block is a complete block or a reduction block under the scope |
162 | * \param self The schedule state |
163 | * \param block_sref The sref of the block to be checked |
164 | * \param scope_root_sref The sref to the scope root of the block |
165 | * \throw ScheduleError If the block is neither a complete block nor a reduction block |
166 | */ |
167 | void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, |
168 | const StmtSRef& scope_root_sref); |
169 | |
170 | /*! |
171 | * \brief Check the subtree compact dataflow property. The scope root may have one or more subtrees |
172 | * rooted at its direct children, and this property requires all the blocks of the subtree |
173 | * that the specified sref is in to be local complete blocks or local reduction blocks. |
174 | * \param self The schedule state |
175 | * \param subtree_root The sref of the subtree root to be checked |
176 | */ |
177 | void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root); |
178 | /*! |
179 | * \brief Check if the block is an output block, i.e. the block writes to at least one buffer |
180 | * that is not allocated under the current scope |
181 | * \param self The schedule state |
182 | * \param block_sref The sref of the block to be checked |
183 | * \param scope_root_sref The sref to the scope root of the block |
184 | * \return A boolean flag indicating if the block is an output block |
185 | */ |
186 | bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, |
187 | const StmtSRef& scope_root_sref); |
188 | |
189 | /*! |
190 | * \brief Check that the block is not an output block, i.e. all the buffers the block writes to |
191 | * are allocated under the current scope |
192 | * \param self The schedule state |
193 | * \param block_sref The sref of the block to be checked |
194 | * \param scope_root_sref The sref to the scope root of the block |
195 | * \throw ScheduleError If the block is an output block |
196 | */ |
197 | void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, |
198 | const StmtSRef& scope_root_sref); |
199 | |
200 | /*! |
201 | * \brief Extracts the iter types of the block vars |
202 | * \param block_sref The sref of the block to be checked |
203 | * \return A vector of the types (IterVarType) of the block vars |
204 | */ |
205 | std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref); |
206 | |
207 | /*! |
208 | * \brief Checks if a block could be considered as a "write cache" |
209 | * \param block_sref The sref of the block to be checked |
210 | * \return A boolean flag indicating if the block is a write cache |
211 | */ |
212 | bool IsWriteCache(const StmtSRef& block_sref); |
213 | |
214 | /******** Binding ********/ |
215 | /*! |
216 | * \brief Checks whether the block binding in a specific BlockRealize is an affine binding, |
217 | * i.e. the binding can be represented as an injective affine map from the loop iterators. |
218 | * \param realize The BlockRealize to be analyzed |
219 | * \param loop_var_ranges The ranges of the loop variables |
220 | * \param analyzer The arithmetic analyzer |
221 | * \return A boolean flag indicating if the binding is affine |
222 | */ |
223 | bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_var_ranges, |
224 | arith::Analyzer* analyzer); |
225 | |
226 | /*! |
227 | * \brief Check that a block has an affine binding using the cached flag, and throw an exception |
228 | * if the block does not have an affine binding. |
229 | * \param self The schedule state |
230 | * \param block The block to be checked |
231 | * \throw ScheduleError If the input block does not have an affine binding |
232 | */ |
233 | void CheckAffineBinding(const ScheduleState& self, Block block); |
234 | |
235 | /*! |
236 | * \brief Check that a block has an affine binding under the `high_exclusive` sref node, and |
237 | * throw an exception if the block does not have an affine binding. |
238 | * \param self The schedule state |
239 | * \param block The block to be checked |
240 | * \param high_exclusive The highest sref node (exclusive) |
241 | * \throw ScheduleError If the input block does not have an affine binding |
242 | */ |
243 | void CheckPartialAffineBinding(const ScheduleState& self, Block block, |
244 | const Optional<StmtSRef>& high_exclusive); |
245 | |
246 | /*! |
247 | * \brief Extracts the ranges of loop variables in a path of the sref tree |
248 | * \param low_inclusive The lowest node in the path |
249 | * \param high_exclusive The highest node in the path, defaults to the scope root if not specified |
250 | * \param extra_relax_scope If the scope is not global, the method will look beyond the limit and |
251 | * retrieve extra domains. Defaults to the global storage scope. For example, |
252 | * - if the storage scope is warp, it will look upwards for threadIdx.x |
253 | * - if the storage scope is shared, it will look for threadIdx.x/y/z |
254 | * \return The loop domain, as a map from loop variables to their ranges |
255 | */ |
256 | Map<Var, Range> LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, |
257 | const Optional<StmtSRef>& high_exclusive = NullOpt, |
258 | const runtime::StorageScope& extra_relax_scope = // |
259 | runtime::StorageScope{runtime::StorageRank::kGlobal, ""}); |
260 | |
261 | /*! |
262 | * \brief Returns the bindings of the block vars |
263 | * \param realize The BlockRealize to be analyzed |
264 | * \return The block var bindings, as a map from Var to PrimExpr |
265 | */ |
266 | Map<Var, PrimExpr> GetBindings(const BlockRealize& realize); |
267 | |
268 | /*! |
269 | * \brief Get the vars involved in the bindings of data parallel block vars and reduction block |
270 | * vars, respectively |
271 | * \param block_realize The BlockRealize to be analyzed |
272 | * \param data_par_vars The vars that appear in the binding of any data parallel block iter |
273 | * \param reduce_vars The vars that appear in the binding of any reduction block iter |
274 | * \return A boolean indicating whether the block has any block iter that is neither a data |
275 | * parallel block iter nor a reduction block iter |
276 | */ |
277 | bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, |
278 | std::unordered_set<const VarNode*>* data_par_vars, |
279 | std::unordered_set<const VarNode*>* reduce_vars); |
280 | |
281 | /******** Loop properties ********/ |
282 | /*! |
283 | * \brief Check that the loop starts with zero. |
284 | * \param self The schedule state |
285 | * \param loop_sref The StmtSRef that points to the loop to be checked |
286 | * \param analyzer The arithmetic analyzer |
287 | * \throw ScheduleError If the loop doesn't start with zero. |
288 | */ |
289 | void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref, |
290 | arith::Analyzer* analyzer); |
291 | |
292 | /*! |
293 | * \brief Check that a block has a trivial binding, i.e. each block var is bound to an outer loop, |
294 | * from outer to inner. |
295 | * \param self The schedule state |
296 | * \param block_sref The sref of the block to be checked |
297 | * \throw ScheduleError If the block does not have trivial bindings |
298 | */ |
299 | void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref); |
300 | |
301 | /******** Block-loop relation ********/ |
302 | |
303 | /*! |
304 | * \brief Gets the StmtSRefs of the leaf blocks of the scope that a specific block/loop is in |
305 | * \param self The schedule state |
306 | * \param parent_sref The StmtSRef that points to the parent block/loop |
307 | * \return A list of the StmtSRefs of the leaf blocks |
308 | */ |
309 | Array<StmtSRef> GetChildBlockSRefOnSRefTree(const ScheduleState& self, const StmtSRef& parent_sref); |
310 | |
311 | /*! |
312 | * \brief Gets the BlockRealizes of the leaf blocks of the scope that a specific block/loop is in |
313 | * \param parent_sref The StmtSRef that points to the parent block/loop |
314 | * \return A list of the BlockRealizes of the leaf blocks |
315 | */ |
316 | Array<BlockRealize> GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref); |
317 | |
318 | /*! |
319 | * \brief Get the BlockRealize of the single child block of the block or loop specified by |
320 | * `parent_sref` on SRef tree, or throw an exception if there are zero or multiple child blocks |
321 | * \param self The schedule state |
322 | * \param parent_sref The StmtSRef that points to the parent block/loop |
323 | * \return The BlockRealize of the single child block |
324 | * \throw ScheduleError If there are zero or multiple child blocks |
325 | */ |
326 | BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self, |
327 | const StmtSRef& parent_sref); |
328 | |
329 | /*! |
330 | * \brief Get the BlockRealize of the input block |
331 | * \param self The schedule state |
332 | * \param block_sref The StmtSRef that points to the queried block |
333 | * \return The BlockRealize of the input block |
334 | */ |
335 | BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref); |
336 | |
337 | /*! |
338 | * \brief Get the IterVarType of the specific loop, according to the blocks that it is bound to |
339 | * \param loop_sref The sref of the loop to be checked |
340 | * \return The IterVarType of the specific loop |
341 | */ |
342 | IterVarType GetLoopIterType(const StmtSRef& loop_sref); |
343 | |
344 | /*! |
345 | * \brief Get the lowest common ancestor of an array of blocks or loops on the sref tree |
346 | * \param srefs The block srefs or loop srefs whose lowest common ancestor is to be queried |
347 | * \return The lowest common ancestor of the input block srefs or loop srefs |
348 | * \note The input array is required to contain at least one sref |
349 | */ |
350 | StmtSRef GetSRefLowestCommonAncestor(const Array<StmtSRef>& srefs); |
351 | |
352 | /*! |
353 | * \brief Checks if multi-level tiling has been applied to the given block. We check this by |
354 | * examining the block's annotation. |
355 | * \param block_sref The sref of the block to be checked |
356 | * \return A boolean indicating whether the block has been multi-level tiled. |
357 | */ |
358 | bool HasBeenMultiLevelTiled(const StmtSRef& block_sref); |
359 | |
360 | /*! |
361 | * \brief Collect all the feasible compute-at locations of the input block |
362 | * \param self The schedule state |
363 | * \param block_sref The sref of the block whose compute-at locations are to be collected |
364 | * \return All the feasible compute-at locations of the input block, given as an array of loop srefs |
365 | * and a vector of their indices among the outer loops of the input block |
366 | */ |
367 | std::pair<Array<StmtSRef>, std::vector<int>> CollectComputeLocation(const ScheduleState& self, |
368 | const StmtSRef& block_sref); |
369 | |
370 | /******** Producer-consumer relation ********/ |
371 | |
372 | /*! |
373 | * \brief Get the producer blocks of the given block under the given scope |
374 | * \param block_sref The sref of the block whose producers are to be retrieved |
375 | * \param scope The block scope that the given block is in |
376 | * \return The producer blocks of the specified block |
377 | */ |
378 | Array<StmtSRef> GetProducers(const StmtSRef& block_sref, const BlockScope& scope); |
379 | |
380 | /*! |
381 | * \brief Get the consumer blocks of the given block under the given scope |
382 | * \param block_sref The sref of the block whose consumers are to be retrieved |
383 | * \param scope The block scope that the given block is in |
384 | * \return The consumer blocks of the specified block |
385 | */ |
386 | Array<StmtSRef> GetConsumers(const StmtSRef& block_sref, const BlockScope& scope); |
387 | |
388 | /*! |
389 | * \brief A solution to split an ordered list of subtrees into two parts, |
390 | * where producers are on the LHS and consumers are on the RHS. |
391 | * For example, subtree[0, 3) are on the LHS, and subtree[3, 6) are on the RHS. |
392 | */ |
393 | struct ProducerConsumerSplit { |
394 | /*! \brief Indicates that all producers fall into `subtrees[0, last_producer_position]` */ |
395 | int last_producer_position; |
396 | /*! \brief Indicates that all consumers fall into `subtrees[first_consumer_position, ...)` */ |
397 | int first_consumer_position; |
398 | /*! \brief The number of given producers visited in `subtrees` */ |
399 | int n_producers_visited; |
400 | /*! \brief The number of given consumers visited in `subtrees` */ |
401 | int n_consumers_visited; |
402 | /*! |
403 | * \brief Find a split among the given `subtrees` |
404 | * \param state The schedule state |
405 | * \param subtrees The ordered list of subtrees to be split |
406 | * \param producer_block_srefs The producers |
407 | * \param consumer_block_srefs The consumers |
408 | * \param block2realize If not null, the corresponding BlockRealize to each block in the scope |
409 | * will be saved in this map |
410 | * \return The valid split points are (last_producer_position, first_consumer_position] |
411 | * \throw ScheduleError If no valid split is found |
412 | */ |
413 | static ProducerConsumerSplit Find( |
414 | const ScheduleState& state, const Array<Stmt>& subtrees, |
415 | const Array<StmtSRef>& producer_block_srefs, const Array<StmtSRef>& consumer_block_srefs, |
416 | std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize); |
417 | }; |
418 | |
419 | /******** Block-buffer relation ********/ |
420 | |
421 | /*! |
422 | * \brief Get the n-th read or write buffer of the given block. |
423 | * \param self The schedule state. |
424 | * \param block The queried block. |
425 | * \param n The index of the queried buffer. |
426 | * \param index_type The type of the buffer index, kRead or kWrite. |
427 | * \return The buffer of the n-th read/write region of the block. |
428 | * \throw ScheduleError If the buffer index is out of bounds. |
429 | */ |
430 | Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, |
431 | BufferIndexType index_type); |
432 | |
433 | /*! |
434 | * \brief Get the n-th read or write buffer region of the given block. |
435 | * \param self The schedule state. |
436 | * \param block The queried block. |
437 | * \param n The index of the queried buffer region. |
438 | * \param index_type The type of the buffer index, kRead or kWrite. |
439 | * \return The n-th read/write region of the block. |
440 | * \throw ScheduleError If the buffer index is out of bounds. |
441 | */ |
442 | BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& block, int n, |
443 | BufferIndexType index_type); |
444 | |
445 | /*! |
446 | * \brief Find the defining site of the buffer in the given block and its ancestors |
447 | * \param block_sref The sref of the block where the search starts |
448 | * \param buffer The buffer whose defining site is to be found |
449 | * \return The defining site of the buffer and whether the buffer is allocated (otherwise the |
450 | * buffer is from match_buffer). |
451 | */ |
452 | std::pair<Optional<StmtSRef>, bool> GetBufferDefiningSite(const StmtSRef& block_sref, |
453 | const Buffer& buffer); |
454 | |
455 | /******** Reduction Block Related ********/ |
456 | |
457 | /*! |
458 | * \brief Get the init values and the BufferStore updates from the input reduction block |
459 | * \param self The schedule state, used for error reporting |
460 | * \param block The block from which the init values and BufferStore updates are extracted |
461 | * \return The extracted init values and BufferStore updates |
462 | * \throw ScheduleError If rfactor or cross-thread reduction cannot be applied to the block |
463 | */ |
464 | std::pair<Array<PrimExpr>, Array<BufferStore>> GetInitValuesAndUpdatesFromReductionBlock( |
465 | const Optional<ScheduleState>& self, Block block); |
466 | |
467 | /*! |
468 | * \brief Check whether the input array of IterVars contains only data-parallel and reduction |
469 | * block iters |
470 | * \param iters The input array of IterVars to be checked |
471 | * \return A boolean indicating whether the input array of IterVars contains only data-parallel and |
472 | * reduction block iters |
473 | */ |
474 | bool ContainsOnlyDataParAndReductionBlockIter(const Array<IterVar>& iters); |
475 | |
476 | /*! |
477 | * \brief Check whether the block's reduction block iters are not used to index the block's output |
478 | * buffers |
479 | * \param block The block to be checked |
480 | * \return A boolean indicating whether the block's reduction block iters are not used to index the |
481 | * block's output buffers |
482 | */ |
483 | bool ReductionIterNotIndexOutputBuffer(const Block& block); |
484 | |
485 | /*! |
486 | * \brief Given a list of reduction identities and a list of reduction combiners, detect the |
487 | * corresponding commutative reducer, and extract the combiner LHS values and combiner RHS values |
488 | * \param self The schedule state (optional) |
489 | * \param identities The reduction identities to be analyzed |
490 | * \param combiners The reduction combiners to be analyzed |
491 | * \return The corresponding CommReducer, the combiner LHS values and the combiner RHS values |
492 | * \throw ScheduleError If no corresponding commutative reducer can be matched |
493 | */ |
494 | std::tuple<CommReducer, Array<PrimExpr>, Array<PrimExpr>> GetReducerAndCombinerLhsRhs( |
495 | const Optional<ScheduleState>& self, const Array<PrimExpr>& identities, |
496 | const Array<BufferStore>& combiners); |
497 | |
498 | /******** Commutative Reducer ********/ |
499 | |
500 | /*! |
501 | * \brief Get all the registered reducer-getter functions |
502 | * \return A vector of the registered reducer-getter functions |
503 | * \sa ReducerRegistry |
504 | */ |
505 | std::vector<runtime::TypedPackedFunc<Optional<CommReducer>(Array<PrimExpr>)>> GetReducerGetters(); |
506 | |
507 | /*! |
508 | * \brief Given the input identities and the combiner BufferStores of a reduction, extract the |
509 | * corresponding commutative reducer, LHS values and RHS values, if possible. |
510 | * \param identities The identities of the reduction |
511 | * \param combiners The combiners of the reduction |
512 | * \param result_reducer Output parameter: the extracted CommReducer |
513 | * \param lhs Output parameter: the extracted LHS values of the reducer |
514 | * \param rhs Output parameter: the extracted RHS values of the reducer |
515 | * \return A boolean indicating whether a corresponding commutative reducer is found |
516 | */ |
517 | bool FromIdentityCombiner(const Array<PrimExpr>& identities, const Array<BufferStore>& combiners, |
518 | CommReducer* result_reducer, Array<PrimExpr>* lhs, Array<PrimExpr>* rhs); |
519 | |
520 | /******** Misc ********/ |
521 | |
522 | /*! |
523 | * \brief Check that the input storage scope string is valid, and throw an error if it is not. |
524 | * \param self The schedule state |
525 | * \param storage_scope The storage scope string to be checked |
526 | * \throw ScheduleError If the input storage scope is not valid |
527 | */ |
528 | void CheckStorageScope(const ScheduleState& self, String storage_scope); |
529 | |
530 | /*! |
531 | * \brief Checks if a block could be successfully computed inline into its consumer |
532 | * \param self The schedule state |
533 | * \param block_sref The sref of the block to be checked |
534 | * \return A boolean indicating whether the block could be successfully computed inline |
535 | */ |
536 | bool CanComputeInline(const ScheduleState& self, const StmtSRef& block_sref); |
537 | |
538 | /*! |
539 | * \brief Checks if a block could be successfully computed inline into its producer |
540 | * \param self The schedule state |
541 | * \param block_sref The sref of the block to be checked |
542 | * \return A boolean indicating whether the block could be successfully computed inline |
543 | */ |
544 | bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref); |
545 | |
546 | /*! |
547 | * \brief Checks if a producer block could be successfully computed at the specific loop. |
548 | * \param self The schedule state |
549 | * \param block_sref The sref of the block to be moved |
550 | * \param loop_sref The sref of the loop that the block is to be moved to |
551 | * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 |
552 | * \return A boolean indicating whether the block could be successfully computed at the loop |
553 | */ |
554 | bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref, |
555 | bool preserve_unit_loops); |
556 | |
557 | /*! |
558 | * \brief Checks if a consumer block could be successfully computed at the specific loop. |
559 | * \param self The schedule state |
560 | * \param block_sref The sref of the block to be moved |
561 | * \param loop_sref The sref of the loop that the block is to be moved to |
562 | * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 |
563 | * \return A boolean indicating whether the block could be successfully reverse-computed at the |
564 | * specific loop |
565 | */ |
566 | bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref, |
567 | const StmtSRef& loop_sref, bool preserve_unit_loops); |
568 | |
569 | /*! |
570 | * \brief Given the access pattern to a buffer, suggest one of the possible layout |
571 | * transformations to maximize the locality of the access pattern. |
572 | * \param buffer The buffer to be transformed |
573 | * \param indices The access pattern to the buffer |
574 | * \param loops The loops above the buffer |
575 | * \param predicate The predicate of the access |
576 | * \param analyzer Arithmetic analyzer |
577 | */ |
578 | Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>& indices, |
579 | const Array<For>& loops, const PrimExpr& predicate, |
580 | arith::Analyzer* analyzer); |
581 | |
582 | /*! |
583 | * \brief Checks if the given AST contains the specified operators |
584 | * \param stmt The AST statement to be checked |
585 | * \param ops The list of operators to be checked |
586 | * \return A boolean indicating whether the AST contains the specified operators |
587 | */ |
588 | bool HasOp(const Stmt& stmt, const Array<Op>& ops); |
589 | |
590 | /*! |
591 | * \brief Checks if the given AST statement contains an if-then-else pattern, including |
592 | * 1) IfThenElse statement |
593 | * 2) Select expression |
594 | * 3) The operator `tir.if_then_else` |
595 | * 4) Block predicates that are not constant true |
596 | * \param stmt The AST statement to be checked |
597 | * \return A boolean indicating whether the statement contains the if-then-else pattern |
598 | */ |
599 | bool HasIfThenElse(const Stmt& stmt); |
600 | |
601 | /*! |
602 | * \brief Given the read/write regions, extract the pattern of their index correspondence, |
603 | * namely, the mapping from the read index to the write index. |
604 | * \param read_region The read region |
605 | * \param write_region The write region |
606 | * \return A tuple of booleans, the extracted pattern |
607 | * 0) exists: if the pattern is found |
608 | * 1) surjective: if the pattern is surjective, i.e. each write index is mapped at least once |
609 | * e.g. A[i, j] = B[i, i, j] |
610 | * 2) injective: if the pattern is injective, i.e. each write index is mapped at most once. |
611 | * e.g. A[i, j] = B[i] |
612 | * 3) ordered: if the mapping is ordered |
613 | * 4) no_const_read: if there is no constant indexing in the read indices, |
614 | * e.g. A[i, j] = B[0, i, j] |
615 | * 5) no_shift_read: if there is no constant shift in the read indices, |
616 | * e.g. A[i, j] = B[i + 1, j] |
617 | */ |
618 | std::tuple</*exists=*/bool, |
619 | /*surjective=*/bool, |
620 | /*injective=*/bool, |
621 | /*ordered=*/bool, |
622 | /*no_const_read=*/bool, |
623 | /*no_shift_read=*/bool> |
624 | AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region); |
625 | |
626 | /*! |
627 | * \brief Check if the block is a data parallel block, i.e. all the block vars are data parallel |
628 | * \param block_sref The sref of the block to be checked |
629 | * \return A boolean flag indicating if the block is a data parallel block |
630 | */ |
631 | bool IsSpatial(const StmtSRef& block_sref); |
632 | |
633 | /*! |
634 | * \brief Check whether a block has a trivial binding, i.e. each block var is bound to an outer |
635 | * loop, from outer to inner. |
636 | * \param self The schedule state |
637 | * \param block_sref The sref of the block to be checked |
638 | * \return A boolean flag indicating if the block has a trivial binding |
639 | */ |
640 | bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref); |
641 | |
642 | /*! |
643 | * \brief Checks if the given block has data reuse opportunities, in which case multi-level |
644 | * tiling is beneficial. |
645 | * \param self The schedule state |
646 | * \param block_sref The sref of the block to be checked |
647 | * \return A boolean indicating whether the block has data reuse opportunities |
648 | */ |
649 | bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref); |
650 | |
651 | /*! |
652 | * \brief Checks if all the blocks in the PrimFunc are spatial |
653 | * \param func The PrimFunc to be checked |
654 | * \return A boolean indicating whether all the blocks in the PrimFunc are spatial |
655 | */ |
656 | bool IsSpatialPrimFunc(const PrimFunc& func); |
657 | |
/*!
 * \brief Checks if rfactor or cross-thread reduction is beneficial to the given block.
 * \param self The schedule state.
 * \param block_sref The sref of the block to be checked.
 * \param max_parallel_extent The maximum parallel jobs on the target.
 * \param max_parallel_basic The maximum cores on the target.
 * \return A boolean indicating whether the operation is beneficial.
 */
bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
                                        const tir::StmtSRef& block_sref, //
                                        int64_t max_parallel_extent, //
                                        int64_t max_parallel_basic);
670 | |
/*!
 * \brief Analyze the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive)
 * Relaxation of the region may be used in upper-bound analysis, i.e. some extra region may be added
 * to the result.
 * \param region The buffer region to be analyzed
 * \param predicate The predicate associated with the access to the buffer region.
 * NOTE(review): presumably the predicate of the block realize that touches the region; confirm
 * against callers.
 * \param dom_low_inclusive The lowest node in the sref tree path
 * \param dom_high_exclusive The highest node in the sref tree path
 * \param analyzer The analyzer
 * \return An n-dimensional integer set
 */
Array<arith::IntSet> AnalyzeRegionUpperBound(const BufferRegion& region, const PrimExpr& predicate,
                                             const StmtSRef& dom_low_inclusive,
                                             const StmtSRef& dom_high_exclusive,
                                             arith::Analyzer* analyzer);
684 | |
/*!
 * \brief Analyze the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive)
 * Some subregion may be discarded during the lower-bound analysis.
 * \param region The buffer region to be analyzed
 * \param predicate The predicate associated with the access to the buffer region.
 * NOTE(review): presumably the predicate of the block realize that touches the region; confirm
 * against callers.
 * \param dom_low_inclusive The lowest node in the sref tree path
 * \param dom_high_exclusive The highest node in the sref tree path
 * \param analyzer The analyzer
 * \return An n-dimensional integer set
 */
Array<arith::IntSet> AnalyzeRegionLowerBound(const BufferRegion& region, const PrimExpr& predicate,
                                             const StmtSRef& dom_low_inclusive,
                                             const StmtSRef& dom_high_exclusive,
                                             arith::Analyzer* analyzer);
699 | |
700 | /*! |
701 | * \brief Check if buffer indices are all Vars and extr |
702 | * \param buffer_access The BufferLoad or BufferStore |
703 | * \return The indices if the indices are all Vars, otherwise NullOpt |
704 | */ |
705 | template <typename T> |
706 | Optional<Array<Var>> CheckTrivialBufferIndices(const T& buffer_access) { |
707 | Array<Var> indices; |
708 | for (const PrimExpr& index : buffer_access->indices) { |
709 | const VarNode* var = index.as<VarNode>(); |
710 | if (var == nullptr) { |
711 | return NullOpt; |
712 | } |
713 | indices.push_back(GetRef<Var>(var)); |
714 | } |
715 | return indices; |
716 | } |
717 | |
/*!
 * \brief Simplify non-trivial expressions
 * \param expr The expression to be simplified
 * \param analyzer The analyzer
 * \return The simplified expression
 *
 * During scheduling, we often need to preserve block iters in trivial expressions that could be
 * simplified to constant values, because simplifying away the block iters may result in loss of
 * information for further scheduling and analysis.
 */
PrimExpr SimplifyNonTrivialExpr(const PrimExpr& expr, arith::Analyzer* analyzer);
729 | |
/*! \brief Necessary information used for tensorization */
class TensorizeInfoNode : public Object {
 public:
  /*! \brief Maps loops in a target block to the ones in an intrinsic description */
  Map<tir::StmtSRef, tir::For> loop_map;
  /*! \brief Maps loops in an intrinsic description to their indices, outer to inner */
  Map<tir::For, Integer> desc_loop_indexer;
  /*! \brief Optional padded extents of the block iters when padding is needed to match the
   * intrinsic description
   */
  Optional<Array<Integer>> block_iter_paddings;

  /*!
   * \brief Visit every field with the given attribute visitor, keyed by the field's name
   * \param v The attribute visitor
   */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("loop_map" , &loop_map);
    v->Visit("desc_loop_indexer" , &desc_loop_indexer);
    v->Visit("block_iter_paddings" , &block_iter_paddings);
  }

  /*! \brief The type key of this node type */
  static constexpr const char* _type_key = "tir.schedule.TensorizeInfo" ;
  TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object);
};
751 | |
/*!
 * \brief Reference class for TensorizeInfoNode, declared with the not-nullable object reference
 * macro
 * \sa TensorizeInfoNode
 */
class TensorizeInfo : public ObjectRef {
 public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode);
};
756 | |
/*!
 * \brief Establish a mapping between loops in a target block and an intrinsic description
 * \param self The schedule state to be tensorized
 * \param block_sref The sref of the target block to match against
 * \param desc_func The prim func describing the computation to be tensorized
 * \param allow_padding Whether to allow padding the block iters to match the intrinsic description
 * \return A TensorizeInfo structure if a valid mapping is found, NullOpt otherwise
 */
Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
                                                const tir::StmtSRef& block_sref,
                                                const tir::PrimFunc& desc_func, bool allow_padding);
768 | |
/*! \brief Necessary information used to perform transformations for tensorization */
class AutoTensorizeMappingInfoNode : public Object {
 public:
  /*! \brief Possible mappings to apply to block iters */
  Array<IndexMap> mappings;

  /* Additional information from AutoTensorizeComparator */

  /*! \brief Mapping from LHS buffer to RHS buffer */
  Map<Buffer, Buffer> lhs_buffer_map;
  /*! \brief Buffer indices on RHS */
  Map<Buffer, Array<PrimExpr>> rhs_buffer_indices;
  /*! \brief Block iters on LHS */
  Array<IterVar> lhs_iters;
  /*! \brief Block iters on RHS */
  Array<IterVar> rhs_iters;

  /*!
   * \brief Visit every field with the given attribute visitor, keyed by the field's name
   * \param v The attribute visitor
   */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("mappings" , &mappings);
    v->Visit("lhs_buffer_map" , &lhs_buffer_map);
    v->Visit("rhs_buffer_indices" , &rhs_buffer_indices);
    v->Visit("lhs_iters" , &lhs_iters);
    v->Visit("rhs_iters" , &rhs_iters);
  }

  /*! \brief The type key of this node type */
  static constexpr const char* _type_key = "tir.schedule.AutoTensorizeMappingInfo" ;
  TVM_DECLARE_FINAL_OBJECT_INFO(AutoTensorizeMappingInfoNode, Object);
};
797 | |
/*!
 * \brief Reference class for AutoTensorizeMappingInfoNode, declared with the not-nullable object
 * reference macro
 * \sa AutoTensorizeMappingInfoNode
 */
class AutoTensorizeMappingInfo : public ObjectRef {
 public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AutoTensorizeMappingInfo, ObjectRef,
                                            AutoTensorizeMappingInfoNode);
};
803 | |
/*!
 * \brief Get the mapping info between a target block and an intrinsic description, including the
 * layout transformations to apply.
 * \param self The schedule state
 * \param block_sref The sref of the compute block for auto tensorization
 * \param desc_func The prim func describing the computation to be tensorized
 * \return An AutoTensorizeMappingInfo structure if a potential mapping is found, NullOpt otherwise.
 * \note Returning a valid AutoTensorizeMappingInfo doesn't guarantee the block can be tensorized.
 * We will need to apply the suggested layout transformations and then match against the tensor
 * intrinsics.
 */
Optional<AutoTensorizeMappingInfo> GetAutoTensorizeMappingInfo(const ScheduleState& self,
                                                               const StmtSRef& block_sref,
                                                               const PrimFunc& desc_func);
818 | |
/*!
 * \brief Perform basic checks for auto tensorization applicability, such as the structure of
 * arithmetic operations and data types.
 * \param sch The schedule to be tensorized
 * \param block_rv The random variable of the compute block for auto tensorization
 * \param desc_func The prim func describing the computation to be tensorized
 * \return true if the basic conditions are met.
 */
bool CheckAutoTensorizeApplicable(const tir::Schedule& sch, const tir::BlockRV& block_rv,
                                  const tir::PrimFunc& desc_func);
829 | } // namespace tir |
830 | } // namespace tvm |
831 | |
832 | #endif // TVM_TIR_SCHEDULE_ANALYSIS_H_ |
833 | |