1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_ |
20 | #define TVM_TIR_SCHEDULE_SCHEDULE_H_ |
21 | |
22 | #include <tvm/support/random_engine.h> |
23 | #include <tvm/tir/index_map.h> |
24 | #include <tvm/tir/schedule/state.h> |
25 | #include <tvm/tir/schedule/trace.h> |
26 | |
27 | namespace tvm { |
28 | namespace tir { |
29 | |
30 | /*! \brief The level of detail used when rendering schedule error messages */ |
31 | enum class ScheduleErrorRenderLevel : int32_t { |
32 | /*! \brief Render a detailed error message */ |
33 | kDetail = 0, |
34 | /*! \brief Render the error in fast mode */ |
35 | kFast = 1, |
36 | /*! \brief Render no error message at all */ |
37 | kNone = 2, |
38 | }; |
39 | |
40 | /*! \brief Type of buffer index: selects a block's read region or its write region */ |
41 | enum class BufferIndexType : int32_t { |
42 | /*! \brief Index of a read buffer */ |
43 | kRead = 0, |
44 | /*! \brief Index of a written buffer */ |
45 | kWrite = 1, |
46 | }; |
47 | |
48 | /**************** Random variable: BlockRV ****************/ |
49 | |
50 | /*! \brief A random variable that evaluates to a TensorIR block; no data fields are declared here, so VisitAttrs is a no-op */ |
51 | class BlockRVNode : public runtime::Object { |
52 | public: |
53 | void VisitAttrs(tvm::AttrVisitor* v) {} |
54 | static constexpr const char* _type_key = "tir.BlockRV" ; |
55 | TVM_DECLARE_FINAL_OBJECT_INFO(BlockRVNode, runtime::Object); |
56 | }; |
57 | |
58 | /*! |
59 | * \brief Managed reference to BlockRVNode, declared as a non-nullable object reference |
60 | * \sa BlockRVNode |
61 | */ |
62 | class BlockRV : public runtime::ObjectRef { |
63 | public: |
64 | /*! \brief Default constructor; only declared here, without an inline body */ |
65 | TVM_DLL BlockRV(); |
66 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BlockRV, runtime::ObjectRef, BlockRVNode); |
67 | }; |
68 | |
69 | /**************** Random variable: LoopRV ****************/ |
70 | |
71 | /*! \brief A random variable that evaluates to a TensorIR for loop; no data fields are declared here, so VisitAttrs is a no-op */ |
72 | class LoopRVNode : public runtime::Object { |
73 | public: |
74 | void VisitAttrs(tvm::AttrVisitor* v) {} |
75 | static constexpr const char* _type_key = "tir.LoopRV" ; |
76 | TVM_DECLARE_FINAL_OBJECT_INFO(LoopRVNode, runtime::Object); |
77 | }; |
78 | |
79 | /*! |
80 | * \brief Managed reference to LoopRVNode, declared as a non-nullable object reference |
81 | * \sa LoopRVNode |
82 | */ |
83 | class LoopRV : public runtime::ObjectRef { |
84 | public: |
85 | /*! \brief Default constructor; only declared here, without an inline body */ |
86 | TVM_DLL LoopRV(); |
87 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LoopRV, runtime::ObjectRef, LoopRVNode); |
88 | }; |
89 | |
90 | /**************** Random variable: ExprRV ****************/ |
91 | |
92 | /*! \brief An expr random variable, represented directly as a PrimExpr */ |
93 | using ExprRV = PrimExpr; |
94 | /*! \brief The node type behind ExprRV, an alias of PrimExprNode */ |
95 | using ExprRVNode = PrimExprNode; |
96 | |
97 | /**************** The Schedule class ****************/ |
98 | /*! \brief Forward declaration; Schedule is a friend of ScheduleNode and the return type of its Copy() */ |
99 | class Schedule; |
100 | |
101 | /*! \brief The user-facing schedule class */ |
102 | class ScheduleNode : public runtime::Object { |
103 | friend class Schedule; |
104 | |
105 | public: |
106 | virtual ~ScheduleNode() = default; |
107 | /*! \brief Type key string of this node type */ |
108 | static constexpr const char* _type_key = "tir.Schedule" ; |
109 | TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, runtime::Object); |
110 | |
111 | public: |
112 | /*! \return The IRModule being scheduled, i.e. the `mod` field of the current state() */ |
113 | virtual IRModule mod() const { return this->state()->mod; } |
114 | /*! \return The internal ScheduleState of scheduling, whose `mod` field is what mod() returns */ |
115 | virtual ScheduleState state() const = 0; |
116 | /*! \return The internally maintained trace of scheduling program execution, wrapped in an Optional */ |
117 | virtual Optional<Trace> trace() const = 0; |
118 | /*! |
119 | * \brief Instruct the schedule to work on a function in the IRModule. |
120 | * |
121 | * By default, the schedule works on the function with the name "main", or the only function in |
122 | * the IRModule if there is only one. If there are multiple functions in the IRModule, and none of |
123 | * them is named "main", users will have to call this method to explicitly specify which |
124 | * function to work on. |
125 | * |
126 | * This sugar function will guide the `GetBlock` method if its `func_name` is not specified. |
127 | * |
128 | * \param func_name The name of the function to work on |
129 | * |
130 | * \sa GetBlock |
131 | */ |
132 | virtual void WorkOn(const String& func_name) = 0; |
133 | /*! |
134 | * \brief Returns a copy of the schedule, including both its state and its symbol table, |
135 | * guaranteeing that |
136 | * 1) The SRef tree is completely reconstructed; |
137 | * 2) The IRModule being scheduled is not modified; |
138 | * 3) All the random variables are valid in the copy, pointing to the corresponding |
139 | * reconstructed sref. |
140 | */ |
141 | virtual Schedule Copy() = 0; |
142 | /*! |
143 | * \brief Seed the randomness |
144 | * \param seed The new random seed: -1 to use device random, otherwise a non-negative value |
145 | */ |
146 | virtual void Seed(support::LinearCongruentialEngine::TRandState seed) = 0; |
147 | /*! \brief Fork the random state; the result is returned as a TRandState value */ |
148 | virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0; |
149 | |
150 | public: |
151 | /******** Lookup/Remove random variables ********/ |
152 | /*! |
153 | * \brief Get the TensorIR block that the given BlockRV evaluates to |
154 | * \param block_rv The BlockRV to be looked up |
155 | * \return The corresponding block |
156 | */ |
157 | virtual Block Get(const BlockRV& block_rv) const = 0; |
158 | /*! |
159 | * \brief Get the TensorIR for loop that the given LoopRV evaluates to |
160 | * \param loop_rv The LoopRV to be looked up |
161 | * \return The corresponding for loop |
162 | */ |
163 | virtual For Get(const LoopRV& loop_rv) const = 0; |
164 | /*! |
165 | * \brief Get the expr corresponding to the specific random variable |
166 | * \param expr_rv The ExprRV (an alias of PrimExpr) to be looked up |
167 | * \return The corresponding expr |
168 | */ |
169 | virtual PrimExpr Get(const ExprRV& expr_rv) const = 0; |
170 | /*! |
171 | * \brief Get the block sref (StmtSRef) corresponding to the specific BlockRV |
172 | * \param block_rv The BlockRV to be looked up |
173 | * \return The corresponding block sref |
174 | */ |
175 | virtual StmtSRef GetSRef(const BlockRV& block_rv) const = 0; |
176 | /*! |
177 | * \brief Get the loop sref (StmtSRef) corresponding to the specific LoopRV |
178 | * \param loop_rv The LoopRV to be looked up |
179 | * \return The corresponding loop sref |
180 | */ |
181 | virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0; |
182 | /*! |
183 | * \brief Check the existence of a specific BlockRV |
184 | * \param block_rv The BlockRV to be looked up |
185 | * \return Whether the corresponding block exists |
186 | */ |
187 | virtual bool HasBlock(const BlockRV& block_rv) const = 0; |
188 | /*! |
189 | * \brief Get the block/loop sref corresponding to the specific statement (virtual, but not pure) |
190 | * \param stmt Raw pointer to the statement to be looked up |
191 | * \return The corresponding block/loop sref |
192 | */ |
193 | virtual StmtSRef GetSRef(const StmtNode* stmt) const; |
194 | /*! |
195 | * \brief Convenience overload: look up the block/loop sref of a statement held by reference |
196 | * \param stmt The statement to be looked up; `stmt.get()` is forwarded to another GetSRef overload |
197 | * \return The corresponding block/loop sref |
198 | */ |
199 | StmtSRef GetSRef(const Stmt& stmt) const { return GetSRef(stmt.get()); } |
200 | /*! |
201 | * \brief Remove a block random variable from the symbol table |
202 | * \param block_rv The block random variable to be removed |
203 | */ |
204 | virtual void RemoveRV(const BlockRV& block_rv) = 0; |
205 | /*! |
206 | * \brief Remove a loop random variable from the symbol table |
207 | * \param loop_rv The loop random variable to be removed |
208 | */ |
209 | virtual void RemoveRV(const LoopRV& loop_rv) = 0; |
210 | /*! |
211 | * \brief Remove an integer random variable (an ExprRV) from the symbol table |
212 | * \param expr_rv The expr random variable to be removed |
213 | */ |
214 | virtual void RemoveRV(const ExprRV& expr_rv) = 0; |
215 | |
216 | public: |
217 | /******** Schedule: Sampling ********/ |
218 | /*! |
219 | * \brief Sample an integer given the probability distribution |
220 | * \param candidates The integer candidates to sample from |
221 | * \param probs The probability distribution of the candidates |
222 | * \param decision The sampling decision; defaults to NullOpt |
223 | * \return The random variable sampled from candidates |
224 | */ |
225 | virtual ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs, |
226 | Optional<Integer> decision = NullOpt) = 0; |
227 | /*! |
228 | * \brief Sample the factors to perfectly tile a specific loop |
229 | * \param loop_rv The loop to be tiled |
230 | * \param n The number of tiles to be sampled |
231 | * \param max_innermost_factor The maximum tile size allowed to be sampled in the innermost loop |
232 | * \param decision The sampling decision; defaults to NullOpt |
233 | * \return A list of length `n`, the random perfect tile sizes sampled |
234 | */ |
235 | virtual Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, |
236 | Optional<Array<Integer>> decision = NullOpt) = 0; |
237 | /*! |
238 | * \brief Sample a compute-at location of the given block |
239 | * \param block_rv The block whose compute-at location is to be sampled |
240 | * \param decision The sampling decision; defaults to NullOpt |
241 | * \return The sampled loop where the input block is to be computed at |
242 | */ |
243 | virtual LoopRV SampleComputeLocation(const BlockRV& block_rv, |
244 | Optional<Integer> decision = NullOpt) = 0; |
245 | |
246 | /******** Schedule: Get blocks & loops ********/ |
247 | /*! |
248 | * \brief Retrieve a block in a specific function by its name |
249 | * |
250 | * By default, if `func_name` is not specified, the schedule will search for the block in the |
251 | * function that is currently being "worked on". To switch the function to be worked on, use |
252 | * `WorkOn` before calling this method. |
253 | * |
254 | * \param name The name of the block to be retrieved |
255 | * \param func_name The name of the function; defaults to NullOpt (see above) |
256 | * \return The block retrieved |
257 | * \note An indexing error is raised if zero or multiple blocks exist with the specific name |
258 | * |
259 | * \sa WorkOn |
260 | */ |
261 | virtual BlockRV GetBlock(const String& name, const Optional<String>& func_name = NullOpt) = 0; |
262 | /*! |
263 | * \brief Get the parent loops of the block in its scope, from outer to inner |
264 | * \param block_rv The block to be queried |
265 | * \return A list of loops above the given block in its scope, from outer to inner |
266 | */ |
267 | virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0; |
268 | /*! |
269 | * \brief Get the leaf blocks of a specific scope |
270 | * \param block_rv The block at which the scope is rooted |
271 | * \return A list of child blocks |
272 | */ |
273 | virtual Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) = 0; |
274 | /*! |
275 | * \brief Get the leaf blocks under a specific loop |
276 | * \param loop_rv The loop under which the blocks are collected |
277 | * \return A list of child blocks |
278 | */ |
279 | virtual Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) = 0; |
280 | /*! |
281 | * \brief Get the producers of a specific block, under the same block scope |
282 | * \param block_rv The block to be queried |
283 | * \return A list of blocks, the producers of the given block under the same scope of the given |
284 | * block |
285 | */ |
286 | virtual Array<BlockRV> GetProducers(const BlockRV& block_rv) = 0; |
287 | /*! |
288 | * \brief Get the consumers of a specific block, under the same block scope |
289 | * \param block_rv The block to be queried |
290 | * \return A list of blocks, the consumers of the given block within the same scope as the given |
291 | * block |
292 | */ |
293 | virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0; |
294 | /******** Schedule: Transform loops ********/ |
295 | /*! |
296 | * \brief Fuse a list of consecutive loops into one. It requires: |
297 | * 1) The loops can't have annotations or thread bindings. |
298 | * 2) The (i+1)-th loop must be the only child of the i-th loop. |
299 | * 3) All loops must start with 0. |
300 | * 4) The domain of a loop to be fused cannot depend on another loop to be fused. |
301 | * \param loop_rvs The loops to be fused |
302 | * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings (default: true) |
303 | * \return The new loop after fusion |
304 | */ |
305 | virtual LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters = true) = 0; |
306 | /*! |
307 | * \brief Split a loop into a list of consecutive loops. It requires: |
308 | * 1) The loop can't have annotation or thread binding. |
309 | * 2) The loop must start with 0. |
310 | * \param loop_rv The loop to be split |
311 | * \param factors The positive tiling factors; at most one of them may be `NullOpt`, which means |
312 | * that factor is inferred. |
313 | * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings (default: true) |
314 | * \return The new loops after split |
315 | */ |
316 | virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors, |
317 | bool preserve_unit_iters = true) = 0; |
318 | /*! |
319 | * \brief Reorder a list of loops. It doesn't require the loops to be consecutive. |
320 | * It requires: |
321 | * 1) The loops are in the same chain. That means: the loops can be ordered to [l_1, l_2, ... , |
322 | * l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between |
323 | * l_1 and l_n (which also indicates they are under the same scope). |
324 | * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops. |
325 | * 3) For every block under the loop nests, its block binding must be affine, and the block |
326 | * variables must be either data parallel or reduction. |
327 | * 4) No duplicated loops are allowed in the arguments. |
328 | * \param ordered_loop_rvs The loops, listed in the new order |
329 | */ |
330 | virtual void Reorder(const Array<LoopRV>& ordered_loop_rvs) = 0; |
331 | /*! |
332 | * \brief Create a new unit loop on top of the specific block. |
333 | * \param block_rv The block above which the new loop is created |
334 | * \return The newly created loop |
335 | */ |
336 | virtual LoopRV AddUnitLoop(const BlockRV& block_rv) = 0; |
337 | /*! |
338 | * \brief Create a new unit loop on top of the specific loop. |
339 | * \param loop_rv The loop above which the new loop is created |
340 | * \return The newly created loop |
341 | */ |
342 | virtual LoopRV AddUnitLoop(const LoopRV& loop_rv) = 0; |
343 | /******** Schedule: Manipulate ForKind ********/ |
344 | /*! |
345 | * \brief Parallelize the input loop. It requires: |
346 | * 1) The scope block that the loop is in should have the stage-pipeline property |
347 | * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine |
348 | * bindings |
349 | * 3) For each block under the loop, the loop can only be contained in data-parallel block iters' |
350 | * bindings |
351 | * \param loop_rv The loop to be parallelized |
352 | */ |
353 | virtual void Parallel(const LoopRV& loop_rv) = 0; |
354 | /*! |
355 | * \brief Vectorize the input loop. It requires: |
356 | * 1) The scope block that the loop is in should have the stage-pipeline property |
357 | * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine |
358 | * bindings |
359 | * 3) For each block under the loop, the loop can only be contained in data-parallel block iters' |
360 | * bindings |
361 | * \param loop_rv The loop to be vectorized |
362 | */ |
363 | virtual void Vectorize(const LoopRV& loop_rv) = 0; |
364 | /*! |
365 | * \brief Bind the input loop to the given thread axis. It requires: |
366 | * 1) The scope block that the loop is in should have the stage-pipeline property |
367 | * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine |
368 | * bindings |
369 | * 3) For each block under the loop, if the thread axis starts with "threadIdx", the loop can only |
370 | * be contained in data-parallel block iter and reduction block iters' bindings. Otherwise the |
371 | * loop can only be contained in data-parallel block iters' bindings |
372 | * \param loop_rv The loop to be bound to the thread axis |
373 | * \param thread_axis The thread axis to be bound to the loop |
374 | */ |
375 | virtual void Bind(const LoopRV& loop_rv, const String& thread_axis) = 0; |
376 | /*! |
377 | * \brief Unroll the input loop. It has no requirements |
378 | * \param loop_rv The loop to be unrolled |
379 | */ |
380 | virtual void Unroll(const LoopRV& loop_rv) = 0; |
381 | /******** Schedule: Insert cache stages ********/ |
382 | /*! |
383 | * \brief Create a block that reads a buffer region into a read cache. It requires: |
384 | * 1) There is at most one block that writes the buffer in the scope. |
385 | * 2) The scope block has the stage-pipeline property. |
386 | * \param block_rv The consumer block of the target buffer. |
387 | * \param read_buffer_index The index of the buffer in block's read region. |
388 | * \param storage_scope The target storage scope. |
389 | * \param consumer_blocks An optional list of consumers of the cache to rewrite; empty by default. |
390 | * \return The cache stage block. |
391 | */ |
392 | virtual BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, |
393 | const String& storage_scope, |
394 | const Array<BlockRV> consumer_blocks = {}) = 0; |
395 | /*! |
396 | * \brief Create a block that writes a buffer region into a write cache. It requires: |
397 | * 1) There is only one block that writes the target buffer. |
398 | * 2) The scope block has the stage-pipeline property. |
399 | * \param block_rv The producer of the buffer |
400 | * \param write_buffer_index The index of the buffer in block's write region |
401 | * \param storage_scope The target storage scope |
402 | * \param consumer_blocks An optional list of consumers to read from cache directly; empty by default. |
403 | * \return The cache stage block. |
404 | */ |
405 | virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, |
406 | const String& storage_scope, |
407 | const Array<BlockRV> consumer_blocks = {}) = 0; |
408 | /*! |
409 | * \brief Create 2 blocks that read & write a buffer region into a read/write cache. |
410 | * It requires that the target block both reads & writes the target buffer. |
411 | * \param block_rv The target block that operates on the target buffer. |
412 | * \param read_buffer_index The index of the buffer in block's read region. |
413 | * \param storage_scope The target storage scope |
414 | * \return The cache stage blocks, cache read block together with cache write block. |
415 | */ |
416 | virtual Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index, |
417 | const String& storage_scope) = 0; |
418 | /*! |
419 | * \brief Create a block to cache precomputed index for later use. |
420 | * If there is no index computation, keep it unchanged. |
421 | * \param block_rv The target block |
422 | * \param storage_scope The storage scope of cached block |
423 | * \param cse_thresh The repeat threshold that determines a common subexpression |
424 | * \return The cache stage blocks. |
425 | */ |
426 | virtual Array<BlockRV> CacheIndex(const BlockRV& block_rv, const String& storage_scope, |
427 | int cse_thresh) = 0; |
428 | /*! |
429 | * \brief Create a block that reads/writes a buffer region into a read/write cache with reindexing. |
430 | * The layout of the cache will follow the iterators of the block that reads/writes the |
431 | * buffer. It requires: |
432 | * 1) There is only one block that reads/writes the target buffer |
433 | * 2) There is only one buffer load/store of this buffer in the block |
434 | * \param block_rv The block that operates on the target buffer. |
435 | * \param buffer_index The index of the buffer in block's read or write region. |
436 | * \param buffer_index_type The type of the buffer index, kRead or kWrite. |
437 | * \return The reindex stage block. |
438 | */ |
439 | virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, |
440 | BufferIndexType buffer_index_type) = 0; |
441 | /******** Schedule: Compute location ********/ |
442 | /*! |
443 | * \brief Move a producer block under the specific loop, and regenerate the |
444 | * loops induced by the block so that the buffer region produced by the producer block could |
445 | * cover those regions consumed by its consumer blocks under the given loop. It requires: |
446 | * 1) `block` and `loop` are under the same scope, `loop` is not the ancestor of `block` |
447 | * 2) The scope block has stage-pipeline property |
448 | * 3) The subtree of the scope block, where the given block is in, satisfies the compact dataflow |
449 | * condition. i.e. all the blocks in the scope block's subtree must be either complete block or |
450 | * reduction block |
451 | * 4) The block is not an output block with regard to the scope block, i.e. the buffers written by |
452 | * the block are allocated under the scope block |
453 | * 5) All the consumers of the block are under the given loop |
454 | * \param block_rv The block to be moved |
455 | * \param loop_rv The loop under which the block is to be moved |
456 | * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 |
457 | * \param index The block index of the loop body subtree blocks (default: -1): |
458 | * - `index = -1` means inserted into the last possible insertion point; |
459 | * - `index = -2` means inserted into the first possible insertion point; |
460 | * - Otherwise, `index` is a nonnegative number that indicates the insertion point |
461 | */ |
462 | virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, |
463 | int index = -1) = 0; |
464 | /*! |
465 | * \brief Move a consumer block under the specific loop, and regenerate the |
466 | * loops induced by the block so that the buffer region consumed by the consumer block could |
467 | * cover those regions produced by its producer blocks under the given loop. It requires: |
468 | * 1) `block` and `loop` are under the same scope, `loop` is not the ancestor of `block` |
469 | * 2) The scope block has stage-pipeline property |
470 | * 3) The subtree of the scope block, where the given block is in, satisfies the compact dataflow |
471 | * condition. i.e. all the blocks in the scope block's subtree must be either complete block or |
472 | * reduction block |
473 | * 4) All the producers of the block are under the given loop |
474 | * |
475 | * \param block_rv The block to be moved |
476 | * \param loop_rv The loop under which the block is to be moved |
477 | * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 |
478 | * \param index The block index of the loop body subtree blocks (default: -1): |
479 | * - `index = -1` means inserted into the last possible insertion point; |
480 | * - `index = -2` means inserted into the first possible insertion point; |
481 | * - Otherwise, `index` is a nonnegative number that indicates the insertion point |
482 | */ |
483 | virtual void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, |
484 | bool preserve_unit_loops, int index = -1) = 0; |
485 | /*! |
486 | * \brief Inline a block into its consumer(s). It requires: |
487 | * 1) The block is a complete non-root block, which only produces one buffer |
488 | * 2) The block must not be the only leaf in the scope. |
489 | * 3) The body of the block must be a BufferStore statement in the form of, |
490 | * A[i, j, k, ...] = ... |
491 | * where the indices of the LHS are all distinct atomic variables, |
492 | * and no variables other than those indexing variables are allowed in the statement. |
493 | * \param block The block to be inlined into its consumer(s) |
494 | */ |
495 | virtual void ComputeInline(const BlockRV& block) = 0; |
496 | /*! |
497 | * \brief Inline a block into its only producer. It requires: |
498 | * 1) The block is a complete non-root block, which only produces and consumes one buffer |
499 | * 2) The block must not be the only leaf in the scope. |
500 | * 3) The only producer of the block is a read-after-write producer and a complete non-root block |
501 | * 4) The body of the block must be a BufferStore statement in the form of, |
502 | * B[f(i, j, k, ...)] = g(i, j, k, A[i, j, k, ...] ...) |
503 | * where the indices of each `BufferLoad` on the RHS are all distinct atomic variables, |
504 | * and no variables other than those indexing variables are allowed in the statement. |
505 | * \param block The block to be inlined into its producer |
506 | */ |
507 | virtual void ReverseComputeInline(const BlockRV& block) = 0; |
508 | /******** Schedule: Reduction ********/ |
509 | /*! |
510 | * \brief Decompose a reduction block into two separate blocks. |
511 | * a) The init block, which is translated from the init statement of the reduction block; |
512 | * b) The update block, which is the original block without init statement. |
513 | * |
514 | * The init block is inserted right before the given loop. |
515 | * |
516 | * The schedule primitive requires: |
517 | * 1) The input block is a reduction block. |
518 | * 2) The input loop is the ancestor of the block. |
519 | * 3) The input loop is not lower than all the loops related to reduce block var. |
520 | * \param block_rv The reduction block to be decomposed |
521 | * \param loop_rv The loop right before which the init block is inserted. |
522 | * \return The init block |
523 | */ |
524 | virtual BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) = 0; |
525 | /*! |
526 | * \brief Factorize an associative reduction block by the specified loop. |
527 | * \details An associative reduction cannot be parallelized directly, |
528 | * because it leads to a potential race condition during accumulation. |
529 | * Alternatively, the reduction could be factorized on a loop with the following steps: |
530 | * - Step 1: evenly slice the reduction into `n` separate chunks, where `n` is the loop extent |
531 | * - Step 2: compute the chunks separately and write the result into `n` intermediate buffers; |
532 | * - Step 3: accumulate the `n` separate buffers into the result buffer. |
533 | * Note that the Step 2 above introduces opportunities for parallelization. |
534 | * RFactor is a schedule primitive that implements the transformation described above. |
535 | * \param loop_rv The loop outside block we want to do rfactor |
536 | * \param factor_axis The position where the new dimension is placed in the newly introduced rfactor |
537 | * buffer. Suppose the original reduction block writes to buffer `B` with |
538 | * ndim(B) dimensions, then `factor_axis` should be in range `[-ndim(B) - 1, |
539 | * ndim(B)]`, and the negative index will be normalized to a non-negative one |
540 | * \return The rfactor block |
541 | */ |
542 | virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0; |
543 | /******** Schedule: Block annotation ********/ |
544 | /*! |
545 | * \brief Set alignment requirement for specific dimension such that |
546 | * stride[axis] == k * factor + offset for some k. This is useful to set memory layout for |
547 | * a friendlier memory access pattern. For example, we can set alignment to be factor=2, |
548 | * offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared |
549 | * memory. |
550 | * \param block_rv The producer block of the buffer |
551 | * \param buffer_index The index of the buffer in block's write region |
552 | * \param axis The dimension to be specified for alignment |
553 | * \param factor The `factor` term of the stride formula above |
554 | * \param offset The `offset` term of the stride formula above |
555 | */ |
556 | virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, |
557 | int offset) = 0; |
558 | /*! |
559 | * \brief Set the storage scope of a buffer, where the buffer is specified by a block and a |
560 | * write-index |
561 | * \param block_rv The producer block of the buffer |
562 | * \param buffer_index The index of the buffer in block's write region |
563 | * \param storage_scope The storage scope to be set |
564 | */ |
565 | virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0; |
566 | /******** Schedule: Blockize & Tensorize ********/ |
567 | /*! |
568 | * \brief Convert the subtree rooted at a specific loop into a block. |
569 | * \param loop_rv The root of the subtree |
570 | * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings (default: true) |
571 | * \return The new block |
572 | */ |
573 | virtual BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters = true) = 0; |
574 | /*! |
575 | * \brief Tensorize the computation enclosed by the given loop with the tensor intrin. |
576 | * \param loop_rv The loop to be tensorized |
577 | * \param intrin Name of the tensor intrinsic |
578 | * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings (default: true) |
579 | */ |
580 | virtual void Tensorize(const LoopRV& loop_rv, const String& intrin, |
581 | bool preserve_unit_iters = true) = 0; |
582 | /*! |
583 | * \brief Tensorize the computation enclosed by the given block with the tensor intrin. |
584 | * \param block_rv The block to be tensorized |
585 | * \param intrin Name of the tensor intrinsic |
586 | * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings (default: true) |
587 | */ |
588 | virtual void Tensorize(const BlockRV& block_rv, const String& intrin, |
589 | bool preserve_unit_iters = true) = 0; |
590 | |
591 | /******** Schedule: Annotation ********/ |
592 | /*! |
593 | * \brief Annotate a loop with a key-value pair |
594 | * \param loop_rv The loop to be annotated |
595 | * \param ann_key The annotation key |
596 | * \param ann_val The annotation value, a string or an ExprRV |
597 | */ |
598 | virtual void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) = 0; |
599 | /*! |
600 | * \brief Annotate a block with a key-value pair |
601 | * \param block_rv The block to be annotated |
602 | * \param ann_key The annotation key |
603 | * \param ann_val The annotation value, a string or an ExprRV |
604 | */ |
605 | virtual void Annotate(const BlockRV& block_rv, const String& ann_key, |
606 | const ObjectRef& ann_val) = 0; |
607 | /*! |
608 | * \brief Remove the annotation with key ann_key from a loop |
609 | * \param loop_rv The loop to be unannotated |
610 | * \param ann_key The annotation key |
611 | */ |
612 | virtual void Unannotate(const LoopRV& loop_rv, const String& ann_key) = 0; |
613 | /*! |
614 | * \brief Remove the annotation with key ann_key from a block |
615 | * \param block_rv The block to be unannotated |
616 | * \param ann_key The annotation key |
617 | */ |
618 | virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0; |
619 | |
620 | /******** Schedule: Layout transformation ********/ |
621 | /*! |
622 | * \brief Apply a transformation represented by IndexMap to buffer |
623 | * \details The indices and the access region to the target buffer are transformed by the given |
624 | * index_map. The index_map is used to infer the new shape of the buffer. Buffer must be either |
625 | * a function parameter, or allocated in a block (it cannot be a buffer subregion created via |
626 | * 'match_buffer'). |
627 | * \param block_rv The block that accesses the target buffer. |
628 | * \param buffer_index The index of the buffer in block's read or write region. |
629 | * \param buffer_index_type The type of the buffer index, kRead or kWrite. |
630 | * \param index_map The transformation to apply. |
631 | * |
632 | * \param pad_value The value to write into padding introduced by |
633 | * the transformation; defaults to NullOpt. If the schedule contains a producer block |
634 | * for the specified buffer, the pad value will be written as |
635 | * part of the producer block if possible, or after the producer |
636 | * block otherwise. Otherwise, if the buffer is an input, will |
637 | * insert an annotation block to state that the padding contains |
638 | * the known value. |
639 | * |
640 | * Note: If applied to an input buffer, the calling scope is |
641 | * responsible for ensuring that the pad_value is present. |
642 | * Algebraic simplifications, branch elimination, and other |
643 | * optimizations may assume that this precondition is met, and |
644 | * may result in incorrect results being returned. |
645 | */ |
646 | virtual void TransformLayout(const BlockRV& block_rv, int buffer_index, |
647 | BufferIndexType buffer_index_type, const IndexMap& index_map, |
648 | const Optional<IndexMap>& pad_value = NullOpt) = 0; |
649 | |
650 | /*! |
651 | * \brief Apply a transformation represented by an IndexMap to a block |
652 | * \details The block iters and the block body are transformed by the given index_map. |
653 | * Outer loops corresponding to each new block iter are regenerated. |
654 | * The index_map is required to be bijective affine, because its inverse mapping is needed. |
655 | * \param block_rv The block to be transformed. |
656 | * \param index_map The transformation to apply. |
657 | */ |
658 | virtual void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) = 0; |
659 | |
660 | /*! |
661 | * \brief Set the axis separators of a buffer, where the buffer is specified by a block and a |
662 | * read or write index |
663 | * \param block_rv The block that accesses the target buffer. |
664 | * \param buffer_index The index of the buffer in the block's read or write region. |
665 | * \param buffer_index_type The type of the buffer index, kRead or kWrite. |
666 | * \param axis_separators The axis separators of the buffer. |
667 | */ |
668 | virtual void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, |
669 | BufferIndexType buffer_index_type, |
670 | const Array<IntImm>& axis_separators) = 0; |
671 | |
672 | /******** Schedule: Padding ********/ |
673 | /*! |
674 | * \brief Decompose a padding block into a block filling const pad values and a block |
675 | * writing in-bound values. |
676 | * \param block_rv The block that matches the padding pattern. |
677 | * \param loop_rv The loop above which (i.e. before which) the const filling block is inserted. |
678 | * \return The const pad value filling block. |
679 | */ |
680 | virtual BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) = 0; |
681 | |
682 | /*! |
683 | * \brief Pad the computation of Einsum. |
684 | * \param block_rv The block that matches the Einsum pattern. |
685 | * \param padding The padding for each block iter. |
686 | * \details This schedule primitive identifies the Einsum pattern in the block body, and finds its |
687 | * producer blocks. It then pads the computation of the Einsum pattern and its producer blocks. |
688 | * The output buffer and the producer buffers are resized according to the padding size. It requires |
689 | * the output buffer and the producer buffers to be allocated inside the PrimFunc. |
690 | * |
691 | * The padding is a list of non-negative integers; each element corresponds to the padding for |
692 | * one block iter, in the order of the block iters. The block and its producer blocks should have |
693 | * trivial bindings, i.e. each block iter is bound to a single loop variable. After padding, the |
694 | * block iter extent and the corresponding outer loop are extended by the padding size. |
695 | * |
696 | * The sizes of the producer buffers are inferred from the padding size of the Einsum computation. |
697 | * The producer buffers are padded by the initial value of the corresponding reduction. |
698 | */ |
699 | virtual void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) = 0; |
700 | |
701 | /******** Schedule: Buffer transformation ********/ |
702 | /*! |
703 | * \brief Compute the target buffer via rolling buffering. |
704 | * \details This primitive selects the outermost rollable axis with a positive bound overlap that |
705 | * appears in the block's ancestor loops as the `rolling axis`, folds and circularizes the buffer |
706 | * along the rolling dimension, and appends a block predicate to avoid recomputing overlapping elements. |
707 | * It requires: |
708 | * 1) The buffer to be an intermediate buffer defined via `alloc_buffer`. |
709 | * 2) The LCA of the producer and consumer of the buffer to be a for loop; typically, |
710 | * the producer and consumer of the buffer are cascaded through compute_at. |
711 | * 3) The access region of the buffer to have at least one dimension that contains |
712 | * a positive bound overlap. |
713 | * \param block_rv The producer block of the buffer. |
714 | * \param write_buffer_index The index of the buffer in the block's write region. |
715 | */ |
716 | virtual void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) = 0; |
717 | |
718 | /******** Schedule: Misc ********/ |
719 | /*! \brief A no-op that marks the start of the postprocessing phase of scheduling */ |
720 | virtual void EnterPostproc() = 0; |
721 | }; |
722 | |
723 | /*! |
724 | * \brief Managed reference to ScheduleNode |
725 | * |
726 | * A schedule is a set of transformations that change the order of computation but |
727 | * preserve the semantics of computation. Some examples of schedules: |
728 | * 1) Split a loop into two; |
729 | * 2) Reorder two loops; |
730 | * 3) Inline the computation of a specific buffer into its consumer. |
731 | * |
732 | * The schedule class stores auxiliary information to schedule correctly and efficiently. |
733 | * |
734 | * Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html |
735 | * |
736 | * \sa ScheduleNode |
737 | */ |
738 | class Schedule : public runtime::ObjectRef { |
739 | public: |
740 | /*! |
741 | * \brief Construct a concrete TensorIR schedule from an IRModule |
742 | * \param mod The IRModule to be scheduled |
743 | * \param seed The seed value for the schedule's random state |
744 | * \param debug_mask Do extra correctness checking after the class creation |
745 | * and each time after calling the Replace method. |
746 | * \param error_render_level The level of error rendering |
747 | * \return The concrete schedule created |
748 | * \sa ScheduleDebugMask |
749 | * \note The checks performed include: |
750 | * 1) VerifySRefTree |
751 | * 2) VerifyCachedFlags |
752 | */ |
753 | TVM_DLL static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed, |
754 | int debug_mask, ScheduleErrorRenderLevel error_render_level); |
755 | /*! |
756 | * \brief Construct a traced concrete TensorIR schedule from an IRModule |
757 | * \param mod The IRModule to be scheduled |
758 | * \param seed The seed value for the schedule's random state |
759 | * \param debug_mask Do extra correctness checking after the class creation |
760 | * and each time after calling the Replace method. |
761 | * \param error_render_level The level of error rendering |
762 | * \return The traced concrete schedule created |
763 | * \sa ScheduleDebugMask |
764 | * \note The checks performed include: |
765 | * 1) VerifySRefTree |
766 | * 2) VerifyCachedFlags |
767 | */ |
768 | TVM_DLL static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, |
769 | int debug_mask, ScheduleErrorRenderLevel error_render_level); |
770 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode); |
771 | }; |
772 | |
773 | } // namespace tir |
774 | } // namespace tvm |
775 | |
776 | #endif // TVM_TIR_SCHEDULE_SCHEDULE_H_ |
777 | |