/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_
#define TVM_TIR_SCHEDULE_SCHEDULE_H_

#include <tvm/support/random_engine.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/schedule/state.h>
#include <tvm/tir/schedule/trace.h>

namespace tvm {
namespace tir {

/*! \brief The level of detailed error message rendering */
enum class ScheduleErrorRenderLevel : int32_t {
  /*! \brief Render a detailed error message */
  kDetail = 0,
  /*! \brief Render the error in fast mode */
  kFast = 1,
  /*! \brief No error message at all */
  kNone = 2,
};

/*! \brief Type of buffer index */
enum class BufferIndexType : int32_t {
  /*! \brief Index of a read buffer */
  kRead = 0,
  /*! \brief Index of a written buffer */
  kWrite = 1,
};

/**************** Random variable: BlockRV ****************/

/*! \brief A random variable that evaluates to a TensorIR block */
class BlockRVNode : public runtime::Object {
 public:
  void VisitAttrs(tvm::AttrVisitor* v) {}
  static constexpr const char* _type_key = "tir.BlockRV";
  TVM_DECLARE_FINAL_OBJECT_INFO(BlockRVNode, runtime::Object);
};

/*!
 * \brief Managed reference to BlockRVNode
 * \sa BlockRVNode
 */
class BlockRV : public runtime::ObjectRef {
 public:
  /*! \brief Constructor */
  TVM_DLL BlockRV();
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BlockRV, runtime::ObjectRef, BlockRVNode);
};

/**************** Random variable: LoopRV ****************/

/*! \brief A random variable that evaluates to a TensorIR for loop */
class LoopRVNode : public runtime::Object {
 public:
  void VisitAttrs(tvm::AttrVisitor* v) {}
  static constexpr const char* _type_key = "tir.LoopRV";
  TVM_DECLARE_FINAL_OBJECT_INFO(LoopRVNode, runtime::Object);
};

/*!
 * \brief Managed reference to LoopRVNode
 * \sa LoopRVNode
 */
class LoopRV : public runtime::ObjectRef {
 public:
  /*! \brief Constructor */
  TVM_DLL LoopRV();
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LoopRV, runtime::ObjectRef, LoopRVNode);
};

/**************** Random variable: ExprRV ****************/

/*! \brief An expr random variable */
using ExprRV = PrimExpr;

using ExprRVNode = PrimExprNode;

/**************** The Schedule class ****************/

class Schedule;

/*! \brief The user-facing schedule class */
class ScheduleNode : public runtime::Object {
  friend class Schedule;

 public:
  virtual ~ScheduleNode() = default;

  static constexpr const char* _type_key = "tir.Schedule";
  TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, runtime::Object);

 public:
  /*! \brief Get the IRModule associated with this schedule. */
  virtual IRModule mod() const { return state()->mod; }
  /*! \return The internal state of scheduling */
  virtual ScheduleState state() const = 0;
  /*! \return The internally maintained trace of scheduling program execution */
  virtual Optional<Trace> trace() const = 0;
  /*!
   * \brief Instruct the schedule to work on a function in the IRModule.
   *
   * By default, the schedule works on the function with the name "main", or the only function in
   * the IRModule if there is only one. If there are multiple functions in the IRModule, and none
   * of their names is "main", users will have to call this method to explicitly specify which
   * function to work on.
   *
   * This sugar function will guide the `GetBlock` method if its `func_name` is not specified.
   *
   * \param func_name The name of the function to work on
   *
   * \sa GetBlock
   */
  virtual void WorkOn(const String& func_name) = 0;
  /*!
   * \brief Returns a copy of the schedule, including both its state and its symbol table,
   * guaranteeing that
   * 1) The SRef tree is completely reconstructed;
   * 2) The IRModule being scheduled is not modified;
   * 3) All the random variables are valid in the copy, pointing to the corresponding
   * reconstructed sref
   */
  virtual Schedule Copy() = 0;
  /*!
   * \brief Seed the randomness
   * \param seed The new random seed; -1 to use device random, otherwise a non-negative value
   */
  virtual void Seed(support::LinearCongruentialEngine::TRandState seed) = 0;
  /*! \brief Fork the random state */
  virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0;

 public:
  /******** Lookup/Remove random variables ********/
  /*!
   * \brief Get the block corresponding to the specific BlockRV
   * \param block_rv The BlockRV to be looked up
   * \return The corresponding block
   */
  virtual Block Get(const BlockRV& block_rv) const = 0;
  /*!
   * \brief Get the for loop corresponding to the specific LoopRV
   * \param loop_rv The LoopRV to be looked up
   * \return The corresponding for loop
   */
  virtual For Get(const LoopRV& loop_rv) const = 0;
  /*!
   * \brief Get the expr corresponding to the specific random variable
   * \param expr_rv The random variable to be looked up
   * \return The corresponding expr
   */
  virtual PrimExpr Get(const ExprRV& expr_rv) const = 0;
  /*!
   * \brief Get the block sref corresponding to the specific BlockRV
   * \param block_rv The BlockRV to be looked up
   * \return The corresponding block sref
   */
  virtual StmtSRef GetSRef(const BlockRV& block_rv) const = 0;
  /*!
   * \brief Get the loop sref corresponding to the specific LoopRV
   * \param loop_rv The LoopRV to be looked up
   * \return The corresponding loop sref
   */
  virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0;
  /*!
   * \brief Check the existence of a specific BlockRV
   * \param block_rv The BlockRV to be looked up
   * \return Whether the corresponding block exists
   */
  virtual bool HasBlock(const BlockRV& block_rv) const = 0;
  /*!
   * \brief Get the block/loop sref corresponding to the specific statement
   * \param stmt The statement to be looked up
   * \return The corresponding block/loop sref
   */
  virtual StmtSRef GetSRef(const StmtNode* stmt) const;
  /*!
   * \brief Get the block/loop sref corresponding to the specific statement
   * \param stmt The statement to be looked up
   * \return The corresponding block/loop sref
   */
  StmtSRef GetSRef(const Stmt& stmt) const { return this->GetSRef(stmt.get()); }
  /*!
   * \brief Remove a block random variable from the symbol table
   * \param block_rv The random variable to be removed
   */
  virtual void RemoveRV(const BlockRV& block_rv) = 0;
  /*!
   * \brief Remove a loop random variable from the symbol table
   * \param loop_rv The random variable to be removed
   */
  virtual void RemoveRV(const LoopRV& loop_rv) = 0;
  /*!
   * \brief Remove an integer random variable from the symbol table
   * \param expr_rv The random variable to be removed
   */
  virtual void RemoveRV(const ExprRV& expr_rv) = 0;

 public:
  /******** Schedule: Sampling ********/
  /*!
   * \brief Sample an integer given the probability distribution
   * \param candidates The candidates
   * \param probs The probability distribution of the candidates
   * \param decision The sampling decision
   * \return The random variable sampled from candidates
   */
  virtual ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs,
                                   Optional<Integer> decision = NullOpt) = 0;
  /*!
   * \brief Sample the factors to perfectly tile a specific loop
   * \param loop_rv The loop to be tiled
   * \param n The number of tiles to be sampled
   * \param max_innermost_factor The maximum tile size allowed to be sampled in the innermost loop
   * \param decision The sampling decision
   * \return A list of length `n`, the random perfect tile sizes sampled
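   *
   * \par Example
   * A minimal illustrative sketch (assumes an existing `Schedule sch` and a `LoopRV i`; the
   * variable names are assumptions, not part of this API):
   * \code
   *   // Sample 2 perfect tile sizes for loop `i`, capping the innermost tile at 16,
   *   // then split the loop using the sampled factors.
   *   Array<ExprRV> factors = sch->SamplePerfectTile(i, 2, 16);
   *   Array<LoopRV> tiles = sch->Split(i, {factors[0], factors[1]});
   * \endcode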
   */
  virtual Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor,
                                          Optional<Array<Integer>> decision = NullOpt) = 0;
  /*!
   * \brief Sample a compute-at location of the given block
   * \param block_rv The block whose compute-at location is to be sampled
   * \param decision The sampling decision
   * \return The sampled loop where the input block is to be computed at
   */
  virtual LoopRV SampleComputeLocation(const BlockRV& block_rv,
                                       Optional<Integer> decision = NullOpt) = 0;

  /******** Schedule: Get blocks & loops ********/
  /*!
   * \brief Retrieve a block in a specific function by its name
   *
   * By default, if `func_name` is not specified, the schedule will search for the block in the
   * function that is currently being "worked on". To switch the function to be worked on, use
   * `WorkOn` before calling this method.
   *
   * \param name The name of the block to be retrieved
   * \param func_name The name of the function
   * \return The block retrieved
   * \note An indexing error is raised if zero or multiple blocks exist with the specified name
   *
   * \sa WorkOn
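   *
   * \par Example
   * An illustrative sketch (assumes an existing `Schedule sch` scheduling a PrimFunc that
   * contains a block named "C"; the names are assumptions used for illustration only):
   * \code
   *   BlockRV block = sch->GetBlock("C");
   *   Array<LoopRV> loops = sch->GetLoops(block);  // the loops above "C", outer to inner
   * \endcode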
   */
  virtual BlockRV GetBlock(const String& name, const Optional<String>& func_name = NullOpt) = 0;
  /*!
   * \brief Get the parent loops of the block in its scope, from outer to inner
   * \param block_rv The query block
   * \return A list of loops above the given block in its scope, from outer to inner
   */
  virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0;
  /*!
   * \brief Get the leaf blocks of a specific scope
   * \param block_rv The block where the scope is rooted
   * \return A list of child blocks
   */
  virtual Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) = 0;
  /*!
   * \brief Get the leaf blocks under a specific loop
   * \param loop_rv The loop under which collecting is conducted
   * \return A list of child blocks
   */
  virtual Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) = 0;
  /*!
   * \brief Get the producers of a specific block, under the same block scope
   * \param block_rv The block in the query
   * \return A list of blocks, the producers of the given block under the same scope of the given
   * block
   */
  virtual Array<BlockRV> GetProducers(const BlockRV& block_rv) = 0;
  /*!
   * \brief Get the consumers of a specific block, under the same block scope
   * \param block_rv The block to be queried
   * \return A list of blocks, the consumers of the given block under the same scope of the given
   * block
   */
  virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
  /******** Schedule: Transform loops ********/
  /*!
   * \brief Fuse a list of consecutive loops into one. It requires:
   * 1) The loops can't have annotations or thread bindings.
   * 2) The (i+1)-th loop must be the only child of the i-th loop.
   * 3) All loops must start with 0.
   * 4) The domain of a loop to be fused cannot depend on another loop to be fused.
   * \param loop_rvs The loops to be fused
   * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
   * \return The new loop after fusion
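   *
   * \par Example
   * A minimal illustrative sketch (assumes an existing `Schedule sch` and a `BlockRV block`
   * nested under at least two loops; names are assumptions):
   * \code
   *   Array<LoopRV> loops = sch->GetLoops(block);
   *   LoopRV fused = sch->Fuse({loops[0], loops[1]});  // fuse the two outermost loops
   * \endcode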
   */
  virtual LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters = true) = 0;
  /*!
   * \brief Split a loop into a list of consecutive loops. It requires:
   * 1) The loop can't have annotations or thread bindings.
   * 2) The loop must start with 0.
   * \param loop_rv The loop to be split
   * \param factors The positive tiling factors, at most one of which may be `NullOpt`, meaning
   * that factor is inferred.
   * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
   * \return The new loops after split
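   *
   * \par Example
   * A minimal illustrative sketch (assumes an existing `Schedule sch` and a `LoopRV i`; names
   * are assumptions):
   * \code
   *   // Split `i` into an outer loop of inferred extent and an inner loop of extent 32.
   *   Array<LoopRV> parts = sch->Split(i, {NullOpt, PrimExpr(32)});
   * \endcode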
   */
  virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
                              bool preserve_unit_iters = true) = 0;
  /*!
   * \brief Reorder a list of loops. It doesn't require the loops to be consecutive.
   * It requires:
   * 1) The loops are in the same chain. That means: the loops can be ordered to [l_1, l_2, ... ,
   * l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between
   * l_1 and l_n (which also indicates they are under the same scope).
   * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops.
   * 3) For every block under the loop nests, its block binding must be affine, and the block
   * variables must be either data parallel or reduction.
   * 4) No duplicated loops are allowed in the arguments.
   * \param ordered_loop_rvs The loops in the new order
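   *
   * \par Example
   * A minimal illustrative sketch (assumes an existing `Schedule sch` and a `BlockRV block`
   * nested under three loops; names are assumptions):
   * \code
   *   Array<LoopRV> loops = sch->GetLoops(block);    // e.g. [i, j, k]
   *   sch->Reorder({loops[2], loops[0], loops[1]});  // new order: [k, i, j]
   * \endcode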
   */
  virtual void Reorder(const Array<LoopRV>& ordered_loop_rvs) = 0;
  /*!
   * \brief Create a new unit loop on top of the specific block.
   * \param block_rv The block above which the new loop is created
   * \return The new loop created
   */
  virtual LoopRV AddUnitLoop(const BlockRV& block_rv) = 0;
  /*!
   * \brief Create a new unit loop on top of the specific loop.
   * \param loop_rv The loop above which the new loop is created
   * \return The new loop created
   */
  virtual LoopRV AddUnitLoop(const LoopRV& loop_rv) = 0;
  /******** Schedule: Manipulate ForKind ********/
  /*!
   * \brief Parallelize the input loop. It requires:
   * 1) The scope block that the loop is in should have the stage-pipeline property
   * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
   * bindings
   * 3) For each block under the loop, the loop can only be contained in data-parallel block
   * iters' bindings
   * \param loop_rv The loop to be parallelized
   */
  virtual void Parallel(const LoopRV& loop_rv) = 0;
  /*!
   * \brief Vectorize the input loop. It requires:
   * 1) The scope block that the loop is in should have the stage-pipeline property
   * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
   * bindings
   * 3) For each block under the loop, the loop can only be contained in data-parallel block
   * iters' bindings
   * \param loop_rv The loop to be vectorized
   */
  virtual void Vectorize(const LoopRV& loop_rv) = 0;
  /*!
   * \brief Bind the input loop to the given thread axis. It requires:
   * 1) The scope block that the loop is in should have the stage-pipeline property
   * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
   * bindings
   * 3) For each block under the loop, if the thread axis starts with "threadIdx", the loop can
   * only be contained in data-parallel block iters' and reduction block iters' bindings.
   * Otherwise the loop can only be contained in data-parallel block iters' bindings
   * \param loop_rv The loop to be bound to the thread axis
   * \param thread_axis The thread axis to be bound to the loop
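   *
   * \par Example
   * An illustrative sketch (assumes an existing `Schedule sch` and LoopRVs `bx` and `tx`
   * obtained from splitting a data-parallel loop; names are assumptions):
   * \code
   *   sch->Bind(bx, "blockIdx.x");
   *   sch->Bind(tx, "threadIdx.x");
   * \endcode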
   */
  virtual void Bind(const LoopRV& loop_rv, const String& thread_axis) = 0;
  /*!
   * \brief Unroll the input loop. It requires nothing
   * \param loop_rv The loop to be unrolled
   */
  virtual void Unroll(const LoopRV& loop_rv) = 0;
  /******** Schedule: Insert cache stages ********/
  /*!
   * \brief Create a block that reads a buffer region into a read cache. It requires:
   * 1) There is at most one block that writes the buffer in the scope.
   * 2) The scope block has the stage-pipeline property.
   * \param block_rv The consumer block of the target buffer.
   * \param read_buffer_index The index of the buffer in the block's read region.
   * \param storage_scope The target storage scope.
   * \param consumer_blocks An optional list of consumers of the cache to rewrite.
   * \return The cache stage block.
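   *
   * \par Example
   * An illustrative sketch (assumes an existing `Schedule sch` and a `BlockRV block` whose first
   * read region refers to the buffer to be cached; names are assumptions):
   * \code
   *   // Stage the first read buffer of `block` into shared memory.
   *   BlockRV cache = sch->CacheRead(block, 0, "shared");
   * \endcode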
   */
  virtual BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
                            const String& storage_scope,
                            const Array<BlockRV> consumer_blocks = {}) = 0;
  /*!
   * \brief Create a block that writes a buffer region into a write cache. It requires:
   * 1) There is only one block that writes the target buffer.
   * 2) The scope block has the stage-pipeline property.
   * \param block_rv The producer of the buffer
   * \param write_buffer_index The index of the buffer in the block's write region
   * \param storage_scope The target storage scope
   * \param consumer_blocks An optional list of consumers to read from cache directly.
   * \return The cache stage block.
   */
  virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
                             const String& storage_scope,
                             const Array<BlockRV> consumer_blocks = {}) = 0;
  /*!
   * \brief Create 2 blocks that read & write a buffer region into a read/write cache.
   * It requires the target block to both read and write the target buffer.
   * \param block_rv The target block that operates on the target buffer.
   * \param read_buffer_index The index of the buffer in the block's read region.
   * \param storage_scope The target storage scope
   * \return The cache stage blocks: the cache read block together with the cache write block.
   */
  virtual Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
                                      const String& storage_scope) = 0;
  /*!
   * \brief Create a block to cache precomputed indices for later use.
   * If there is no index computation, keep unchanged.
   * \param block_rv The target block
   * \param storage_scope The storage scope of the cached block
   * \param cse_thresh The repeat threshold that determines a common subexpression
   * \return The cache stage blocks.
   */
  virtual Array<BlockRV> CacheIndex(const BlockRV& block_rv, const String& storage_scope,
                                    int cse_thresh) = 0;
  /*!
   * \brief Create a block that reads/writes a buffer region into a read/write cache with
   * reindexing. The layout of the cache will be the same as indexed by the iterators of the
   * block that reads/writes the buffer. It requires:
   * 1) There is only one block that reads/writes the target buffer
   * 2) There is only one buffer load/store of this buffer in the block
   * \param block_rv The block that operates on the target buffer.
   * \param buffer_index The index of the buffer in the block's read or write region.
   * \param buffer_index_type The type of the buffer index, kRead or kWrite.
   * \return The reindex stage block.
   */
  virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
                          BufferIndexType buffer_index_type) = 0;
  /******** Schedule: Compute location ********/
  /*!
   * \brief Move a producer block under the specific loop, and regenerate the
   * loops induced by the block so that the buffer region produced by the producer block could
   * cover those regions consumed by its consumer blocks under the given loop. It requires:
   * 1) `block` and `loop` are under the same scope, and `loop` is not an ancestor of `block`
   * 2) The scope block has the stage-pipeline property
   * 3) The subtree of the scope block, where the given block is in, satisfies the compact
   * dataflow condition, i.e. all the blocks in the scope block's subtree must be either complete
   * blocks or reduction blocks
   * 4) The block is not an output block with regard to the scope block, i.e. the buffers written
   * by the block are allocated under the scope block
   * 5) All the consumers of the block are under the given loop
   * \param block_rv The block to be moved
   * \param loop_rv The loop under which the block is to be moved
   * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
   * \param index The block index of the loop body subtree blocks:
   * - `index = -1` means inserted into the last possible insertion point;
   * - `index = -2` means inserted into the first possible insertion point;
   * - Otherwise, `index` is a nonnegative number that indicates the insertion point
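   *
   * \par Example
   * A minimal illustrative sketch (assumes an existing `Schedule sch`, a producer
   * `BlockRV producer`, and a `LoopRV i` belonging to its consumer; names are assumptions):
   * \code
   *   sch->ComputeAt(producer, i, false);  // preserve_unit_loops = false
   * \endcode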
   */
  virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
                         int index = -1) = 0;
  /*!
   * \brief Move a consumer block under the specific loop, and regenerate the
   * loops induced by the block so that the buffer region consumed by the consumer block could
   * cover those regions produced by its producer blocks under the given loop. It requires:
   * 1) `block` and `loop` are under the same scope, and `loop` is not an ancestor of `block`
   * 2) The scope block has the stage-pipeline property
   * 3) The subtree of the scope block, where the given block is in, satisfies the compact
   * dataflow condition, i.e. all the blocks in the scope block's subtree must be either complete
   * blocks or reduction blocks
   * 4) All the producers of the block are under the given loop
   *
   * \param block_rv The block to be moved
   * \param loop_rv The loop under which the block is to be moved
   * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
   * \param index The block index of the loop body subtree blocks:
   * - `index = -1` means inserted into the last possible insertion point;
   * - `index = -2` means inserted into the first possible insertion point;
   * - Otherwise, `index` is a nonnegative number that indicates the insertion point
   */
  virtual void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
                                bool preserve_unit_loops, int index = -1) = 0;
  /*!
   * \brief Inline a block into its consumer(s). It requires:
   * 1) The block is a complete non-root block, which only produces one buffer
   * 2) The block must not be the only leaf in the scope.
   * 3) The body of the block must be a BufferStore statement in the form of,
   *    A[i, j, k, ...] = ...
   * where the indices of the LHS are all distinct atomic variables,
   * and no variables other than those indexing variables are allowed in the statement.
   * \param block The block to be inlined to its consumer(s)
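   *
   * \par Example
   * A minimal illustrative sketch (assumes an existing `Schedule sch` whose PrimFunc contains an
   * intermediate elementwise block named "B"; the name is an assumption):
   * \code
   *   BlockRV b = sch->GetBlock("B");
   *   sch->ComputeInline(b);  // fold the computation of "B" into its consumer(s)
   * \endcode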
   */
  virtual void ComputeInline(const BlockRV& block) = 0;
  /*!
   * \brief Inline a block into its only producer. It requires:
   * 1) The block is a complete non-root block, which only produces and consumes one buffer
   * 2) The block must not be the only leaf in the scope.
   * 3) The only producer of the block is a read-after-write producer and a complete non-root
   * block
   * 4) The body of the block must be a BufferStore statement in the form of,
   *    B[f(i, j, k, ...)] = g(i, j, k, A[i, j, k, ...] ...)
   * where the indices of each `BufferLoad` on the RHS are all distinct atomic variables,
   * and no variables other than those indexing variables are allowed in the statement.
   * \param block The block to be inlined to its producer
   */
  virtual void ReverseComputeInline(const BlockRV& block) = 0;
  /******** Schedule: Reduction ********/
  /*!
   * \brief Decompose a reduction block into two separate blocks.
   * a) The init block, which is translated from the init statement of the reduction block;
   * b) The update block, which is the original block without the init statement.
   *
   * The init block is inserted right before the given loop.
   *
   * The schedule primitive requires:
   * 1) The input block is a reduction block.
   * 2) The input loop is an ancestor of the block.
   * 3) The input loop is not lower than all the loops related to the reduce block vars.
   * \param block_rv The reduction block to be decomposed
   * \param loop_rv The loop above which the init block is inserted before.
   * \return The init block
   */
  virtual BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;
  /*!
   * \brief Factorize an associative reduction block by the specified loop.
   * \details An associative reduction cannot be parallelized directly,
   * because it leads to potential race conditions during accumulation.
   * Alternatively, the reduction could be factorized on a loop with the following steps:
   * - Step 1: evenly slice the reduction into `n` separate chunks, where `n` is the loop extent;
   * - Step 2: compute the chunks separately and write the result into `n` intermediate buffers;
   * - Step 3: accumulate the `n` separate buffers into the result buffer.
   * Note that Step 2 above introduces opportunities for parallelization.
   * RFactor is a schedule primitive that implements the transformation described above.
   * \param loop_rv The loop outside the block that we want to rfactor
   * \param factor_axis The position where the new dimension is placed in the newly introduced
   *                    rfactor buffer. Suppose the original reduction block writes to buffer `B`
   *                    with ndim(B) dimensions, then `factor_axis` should be in range
   *                    `[-ndim(B) - 1, ndim(B)]`, and the negative index will be normalized to a
   *                    non-negative one
   * \return The rfactor block
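   *
   * \par Example
   * A minimal illustrative sketch (assumes an existing `Schedule sch` and a `LoopRV ko` obtained
   * by splitting the reduction loop of a reduction block; names are assumptions):
   * \code
   *   // Factorize the reduction along `ko`, placing the new dimension first.
   *   BlockRV rf = sch->RFactor(ko, 0);
   * \endcode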
   */
  virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0;
  /******** Schedule: Block annotation ********/
  /*!
   * \brief Set the alignment requirement for a specific dimension such that
   * stride[axis] == k * factor + offset for some k. This is useful to set the memory layout for
   * a more friendly memory access pattern. For example, we can set alignment to factor=2,
   * offset=1 to avoid bank conflicts for thread access on the higher dimension in GPU shared
   * memory.
   * \param block_rv The producer block of the buffer
   * \param buffer_index The index of the buffer in the block's write region
   * \param axis The dimension to be specified for alignment
   * \param factor The factor multiple of alignment
   * \param offset The required offset factor
   */
  virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
                            int offset) = 0;
  /*!
   * \brief Set the storage scope of a buffer, where the buffer is specified by a block and a
   * write index
   * \param block_rv The producer block of the buffer
   * \param buffer_index The index of the buffer in the block's write region
   * \param storage_scope The storage scope to be set
   */
  virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0;
  /******** Schedule: Blockize & Tensorize ********/
  /*!
   * \brief Convert the subtree rooted at a specific loop into a block.
   * \param loop_rv The root of the subtree
   * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
   * \return The new block
   */
  virtual BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters = true) = 0;
  /*!
   * \brief Tensorize the computation enclosed by the loop with the tensor intrinsic.
   * \param loop_rv The loop to be tensorized
   * \param intrin Name of the tensor intrinsic
   * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
   */
  virtual void Tensorize(const LoopRV& loop_rv, const String& intrin,
                         bool preserve_unit_iters = true) = 0;
  /*!
   * \brief Tensorize the computation enclosed by the block with the tensor intrinsic.
   * \param block_rv The block to be tensorized
   * \param intrin Name of the tensor intrinsic
   * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
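   *
   * \par Example
   * An illustrative sketch (assumes an existing `Schedule sch`, a `BlockRV block` whose
   * computation matches a tensor intrinsic registered under a name held in `intrin_name`; both
   * the block and the intrinsic name are assumptions):
   * \code
   *   String intrin_name = "...";  // name of a registered tensor intrinsic (placeholder)
   *   sch->Tensorize(block, intrin_name);
   * \endcode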
   */
  virtual void Tensorize(const BlockRV& block_rv, const String& intrin,
                         bool preserve_unit_iters = true) = 0;

  /******** Schedule: Annotation ********/
  /*!
   * \brief Annotate a loop with a key-value pair
   * \param loop_rv The loop to be annotated
   * \param ann_key The annotation key
   * \param ann_val The annotation value, a string or an ExprRV
   */
  virtual void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) = 0;
  /*!
   * \brief Annotate a block with a key-value pair
   * \param block_rv The block to be annotated
   * \param ann_key The annotation key
   * \param ann_val The annotation value, a string or an ExprRV
   */
  virtual void Annotate(const BlockRV& block_rv, const String& ann_key,
                        const ObjectRef& ann_val) = 0;
  /*!
   * \brief Unannotate a loop's annotation with key ann_key
   * \param loop_rv The loop to be unannotated
   * \param ann_key The annotation key
   */
  virtual void Unannotate(const LoopRV& loop_rv, const String& ann_key) = 0;
  /*!
   * \brief Unannotate a block's annotation with key ann_key
   * \param block_rv The block to be unannotated
   * \param ann_key The annotation key
   */
  virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0;

  /******** Schedule: Layout transformation ********/
  /*!
   * \brief Apply a transformation represented by IndexMap to a buffer
   * \details The indices and the access region of the target buffer are transformed by the given
   * index_map. The index_map is used to infer the new shape of the buffer. The buffer must be
   * either a function parameter, or allocated in a block (it cannot be a buffer subregion
   * created via 'match_buffer').
   * \param block_rv The block that accesses the target buffer.
   * \param buffer_index The index of the buffer in the block's read or write region.
   * \param buffer_index_type The type of the buffer index, kRead or kWrite.
   * \param index_map The transformation to apply.
   *
   * \param pad_value The value to write into padding introduced by
   * the transformation. If the schedule contains a producer block
   * for the specified buffer, the pad value will be written as
   * part of the producer block if possible, or after the producer
   * block otherwise. Otherwise, if the buffer is an input, will
   * insert an annotation block to state that the padding contains
   * the known value.
   *
   * Note: If applied to an input buffer, the calling scope is
   * responsible for ensuring that the pad_value is present.
   * Algebraic simplifications, branch elimination, and other
   * optimizations may assume that this precondition is met, and
   * may result in incorrect results being returned.
   */
  virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
                               BufferIndexType buffer_index_type, const IndexMap& index_map,
                               const Optional<IndexMap>& pad_value = NullOpt) = 0;

  /*!
   * \brief Apply a transformation represented by IndexMap to a block
   * \details The block iters and the block body are transformed by the given index_map.
   * Outer loops corresponding to each new block iter are regenerated.
   * The index_map is required to be bijective affine since we need its inverse mapping.
   * \param block_rv The block to be transformed
   * \param index_map The transformation to apply.
   */
  virtual void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) = 0;

  /*!
   * \brief Set the axis separator of a buffer, where the buffer is specified by a block and a
   * read or write index
   * \param block_rv The block that accesses the target buffer.
   * \param buffer_index The index of the buffer in the block's read or write region.
   * \param buffer_index_type The type of the buffer index, kRead or kWrite.
   * \param axis_separators The axis separator of the buffer
   */
  virtual void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
                                BufferIndexType buffer_index_type,
                                const Array<IntImm>& axis_separators) = 0;

  /******** Schedule: Padding ********/
  /*!
   * \brief Decompose a padding block into a block filling const pad values and a block
   * writing in-bound values.
   * \param block_rv The block that matches the padding pattern.
   * \param loop_rv The loop above which the const filling block is inserted before.
   * \return The const pad value filling block.
   */
  virtual BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;

  /*!
   * \brief Pad the computation of Einsum.
   * \param block_rv The block that matches the Einsum pattern.
   * \param padding The padding for each block iter.
   * \details This schedule primitive identifies the Einsum pattern in the block body, and finds
   * its producer blocks. It then pads the computation of the Einsum pattern and its producer
   * blocks. The output buffer and the producer buffers are resized according to the padding
   * size. It requires the output buffer and the producer buffers to be allocated inside the
   * PrimFunc.
   *
   * The padding is a list of non-negative integers, each element corresponds to the padding for
   * each block iter in the order of block iters. The block and its producer blocks should have
   * trivial bindings, i.e. each block iter is bound to a single loop variable. After padding,
   * the block iter extent and the corresponding outer loop is extended by the padding size.
   *
   * The sizes of the producer buffers are inferred from the padding size of the Einsum
   * computation. The producer buffers are padded by the initial value of the corresponding
   * reduction.
   */
  virtual void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) = 0;

  /******** Schedule: Buffer transformation ********/
  /*!
   * \brief Compute the target buffer via rolling buffering.
   * \details This primitive selects the outermost rollable axis with a positive bound overlap
   * that appears in the block's ancestor loops as the `rolling axis`, folds and circularizes the
   * buffer along the rolling dimension, and appends a block predicate to avoid recomputing
   * overlapping elements. It requires:
   * 1) The buffer to be an intermediate buffer defined via `alloc_buffer`.
   * 2) The LCA of the producer and consumer of the buffer is a for loop, typically,
   * the producer and consumer of the buffer are cascaded through compute_at.
   * 3) The access region of the buffer has at least one dimension that contains
   * a positive bound overlap.
   * \param block_rv The producer block of the buffer.
   * \param write_buffer_index The index of the buffer in the block's write region.
   */
  virtual void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) = 0;

  /******** Schedule: Misc ********/
  /*! \brief A no-op that marks the start of the postprocessing phase of scheduling */
  virtual void EnterPostproc() = 0;
};

/*!
 * \brief Managed reference to ScheduleNode
 *
 * A schedule is a set of transformations that change the order of computation but
 * preserve the semantics of computation. Some examples of schedules:
 * 1) Split a loop into two;
 * 2) Reorder two loops;
 * 3) Inline the computation of a specific buffer into its consumer
 *
 * The schedule class stores auxiliary information to schedule correctly and efficiently.
 *
 * Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html
 *
 * \sa ScheduleNode
 */
class Schedule : public runtime::ObjectRef {
 public:
  /*!
   * \brief Construct a concrete TensorIR schedule from an IRModule
   * \param mod The IRModule to be scheduled
   * \param seed The seed value for the schedule's random state
   * \param debug_mask Do extra correctness checking after the class creation
   * and each time after calling the Replace method.
   * \param error_render_level The level of error rendering
   * \return The concrete schedule created
   * \sa ScheduleDebugMask
   * \note The checks performed include:
   * 1) VerifySRefTree
   * 2) VerifyCachedFlags
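   *
   * \par Example
   * A minimal illustrative sketch (assumes `mod` is an existing IRModule containing the PrimFunc
   * to be scheduled; the block name is an assumption):
   * \code
   *   Schedule sch = Schedule::Concrete(mod, -1, 0, ScheduleErrorRenderLevel::kDetail);
   *   BlockRV block = sch->GetBlock("C");
   * \endcode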
   */
  TVM_DLL static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
                                   int debug_mask, ScheduleErrorRenderLevel error_render_level);
  /*!
   * \brief Construct a traced concrete TensorIR schedule from an IRModule
   * \param mod The IRModule to be scheduled
   * \param seed The seed value for the schedule's random state
   * \param debug_mask Do extra correctness checking after the class creation
   * and each time after calling the Replace method.
   * \param error_render_level The level of error rendering
   * \return The concrete schedule created
   * \sa ScheduleDebugMask
   * \note The checks performed include:
   * 1) VerifySRefTree
   * 2) VerifyCachedFlags
   */
  TVM_DLL static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
                                 int debug_mask, ScheduleErrorRenderLevel error_render_level);
  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
};

}  // namespace tir
}  // namespace tvm

#endif  // TVM_TIR_SCHEDULE_SCHEDULE_H_