1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_TIR_SCHEDULE_PRIMITIVE_H_
20#define TVM_TIR_SCHEDULE_PRIMITIVE_H_
21
22#include <tvm/support/random_engine.h>
23#include <tvm/tir/schedule/state.h>
24
25#include <vector>
26
27namespace tvm {
28namespace tir {
29
/******** Schedule: Sampling ********/
/*!
 * \brief Sample a random integer from a given range.
 * \param rand_state The pointer to schedule's random state.
 * \param min_inclusive The minimum value of the range, inclusive.
 * \param max_exclusive The maximum value of the range, exclusive.
 * \return The random integer sampled in the given range.
 */
TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state,
                          int32_t min_inclusive, int32_t max_exclusive);
/*!
 * \brief Sample k random integers from given range without replacement, i.e., no duplication.
 * \param rand_state The pointer to schedule's random state
 * \param n The range is defined as 0 to n-1.
 * \param k The total number of samples.
 * \return The randomly selected samples from the n candidates.
 */
std::vector<int32_t> SampleWithoutReplacement(
    support::LinearCongruentialEngine::TRandState* rand_state, int32_t n, int32_t k);
/*!
 * \brief Sample one category from candidates according to the probability weights.
 * \param rand_state The pointer to schedule's random state
 * \param candidates The candidates
 * \param probs The probability distribution of the candidates
 * \param decision The sampling decision, if any
 * \return The random variable sampled from candidates
 */
TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state,
                                  const Array<Integer>& candidates, const Array<FloatImm>& probs,
                                  Optional<Integer>* decision);
/*!
 * \brief Create a sampling function that does multinomial sampling.
 * \param rand_state The random state.
 * \param weights The weights for multinomial sampling.
 * \return The multinomial sampling function.
 */
TVM_DLL std::function<int32_t()> MakeMultinomialSampler(
    support::LinearCongruentialEngine::TRandState* rand_state, const std::vector<double>& weights);
/*!
 * \brief Sample the factors to perfect tile a specific loop
 * \param rand_state The random state
 * \param extent The loop extent to be tiled
 * \param n_splits The number of tiles to be sampled
 * \return A list of length `n_splits`, the random perfect tile sizes sampled
 */
TVM_DLL std::vector<int64_t> SamplePerfectTile(
    support::LinearCongruentialEngine::TRandState* rand_state,  //
    int32_t extent, int32_t n_splits);
/*!
 * \brief Sample the factors to perfect tile a specific loop
 * \param rand_state The random state
 * \param extent The loop extent to be tiled
 * \param n_split The number of tiles to be sampled
 * \param max_innermost_factor The maximum tile size allowed to be sampled in the innermost loop
 * \return A list of length `n_split`, the random perfect tile sizes sampled
 */
TVM_DLL std::vector<int64_t> SamplePerfectTile(
    support::LinearCongruentialEngine::TRandState* rand_state,  //
    int32_t extent, int32_t n_split, int32_t max_innermost_factor);
/*!
 * \brief Sample the factors to perfect tile a specific loop
 * \param rand_state The random state
 * \param loop_sref The loop to be tiled
 * \param n_split The number of tiles to be sampled
 * \param max_innermost_factor The maximum tile size allowed to be sampled in the innermost loop
 * \param decision The sampling decision
 * \return A list of length `n_split`, the random perfect tile sizes sampled
 */
TVM_DLL std::vector<int64_t> SamplePerfectTile(
    support::LinearCongruentialEngine::TRandState* rand_state,  //
    const tir::StmtSRef& loop_sref, int32_t n_split, int32_t max_innermost_factor,
    Optional<Array<Integer>>* decision);
/*!
 * \brief Sample a compute-at location of the given block
 * \param self The schedule state
 * \param rand_state The random state
 * \param block_sref The sref of the block whose compute-at location is to be sampled
 * \param decision The sampling decision
 * \return The sampled loop where the input block is to be computed at
 */
TVM_DLL tir::StmtSRef SampleComputeLocation(
    tir::ScheduleState self, support::LinearCongruentialEngine::TRandState* rand_state,
    const tir::StmtSRef& block_sref, Optional<Integer>* decision);
113
/******** Schedule: Get blocks & loops ********/
/*!
 * \brief Retrieves blocks in a specific function with its name
 * \param self The schedule state
 * \param name The name of the blocks to be retrieved
 * \param gv The function to be retrieved
 * \return A list of blocks with the specific name
 */
Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const GlobalVar& gv);
/*!
 * \brief Gets the parent loops of the block in its scope, from outer to inner
 * \param block_sref The query block
 * \return A list of loops above the given block in its scope, from outer to inner
 */
Array<StmtSRef> GetLoops(const StmtSRef& block_sref);
/*!
 * \brief Get the leaf blocks of a specific block/loop
 * \param self The schedule state
 * \param parent_sref The query block/loop
 * \return A list of leaf blocks inside a specific block/loop
 */
Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref);
/*!
 * \brief Get the producers of a specific block
 * \param self The schedule state
 * \param block_sref The block in the query
 * \return A list of blocks, the producers of the given block
 */
Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref);
/*!
 * \brief Get the consumers of a specific block
 * \param self The schedule state
 * \param block_sref The block in the query
 * \return A list of blocks, the consumers of the given block
 */
Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref);
/******** Schedule: Transform loops ********/
/*!
 * \brief Split a loop into a list of consecutive loops. It requires:
 * 1) The loop can't have annotation or thread binding.
 * 2) The loop must start with 0.
 * \param self The state of the schedule
 * \param loop_sref The sref to the loop being split
 * \param factors The splitting factors
 * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
 * \return An array of srefs to the loops after splitting
 */
TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
                              const Array<PrimExpr>& factors, bool preserve_unit_iters);
/*!
 * \brief Fuse a list of consecutive loops into one. It requires:
 * 1) The loops can't have annotations or thread bindings.
 * 2) The inner loop must be the only child of the outer loop.
 * 3) All loops must start with 0.
 * 4) The domain of a loop to be fused cannot depend on another loop to be fused.
 * \param self The state of the schedule
 * \param loop_srefs An array of srefs to the loops to be fused
 * \param preserve_unit_loops Whether or not to preserve unit iterators in block bindings
 * \return The sref to the fused loop
 */
TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs,
                      bool preserve_unit_loops);
/*!
 * \brief Reorder a list of loops. It doesn't require the loops to be consecutive.
 * It requires:
 * 1) The loops are in the same chain. That means: the loops can be ordered to [l_1, l_2, ... ,
 *    l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between
 *    l_1 and l_n (which also indicates they are under the same scope).
 * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops.
 * 3) For every block under the loop nests, its block binding must be affine, and the block
 *    variables must be either data parallel or reduction.
 * 4) No duplicated loops are allowed in the arguments.
 * \param self The state of the schedule
 * \param ordered_loop_srefs An array of srefs which indicates the new order of loops
 */
TVM_DLL void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs);
191
192/*!
193 * \brief Create a new unit loop on top of the specific block or loop.
194 * \param sref The block/loop above which the new thread_binding loop is created
195 * \param extent The extent of the new thread_binding loop
196 * \param thread_axis The thread axis of the new thread_binding loop
197 * \param attrs Extra loop attributes
198 * \return The new thread_binding loop
199 */
200TVM_DLL StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref);
201
/******** Schedule: Manipulate ForKind ********/
/*!
 * \brief Parallelize the input loop. It requires:
 * 1) The scope block that the loop is in should have stage-pipeline property
 * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
 *    bindings
 * 3) For each block under the loop, the loop can only be contained in data-parallel block iters'
 *    bindings
 * \param self The state of the schedule
 * \param loop_sref The sref of the loop to be parallelized
 */
TVM_DLL void Parallel(ScheduleState self, const StmtSRef& loop_sref);
/*!
 * \brief Vectorize the input loop. It requires:
 * 1) The scope block that the loop is in should have stage-pipeline property
 * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
 *    bindings
 * 3) For each block under the loop, the loop can only be contained in data-parallel block iters'
 *    bindings
 * \param self The state of the schedule
 * \param loop_sref The sref of the loop to be vectorized
 */
TVM_DLL void Vectorize(ScheduleState self, const StmtSRef& loop_sref);
/*!
 * \brief Bind the input loop to the given thread axis. It requires:
 * 1) The scope block that the loop is in should have stage-pipeline property
 * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
 *    bindings
 * 3) For each block under the loop, if the thread axis starts with "threadIdx", the loop can only
 *    be contained in data-parallel block iter and reduction block iters' bindings. Otherwise the
 *    loop can only be contained in data-parallel block iters' bindings
 * \param self The state of the schedule
 * \param loop_sref The sref of the loop to be bound to the thread axis
 * \param thread_axis The thread axis to be bound to the loop
 */
TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& thread_axis);
/*!
 * \brief Unroll the input loop. It requires nothing
 * \param self The state of the schedule
 * \param loop_sref The loop to be unrolled
 */
TVM_DLL void Unroll(ScheduleState self, const StmtSRef& loop_sref);
/******** Schedule: Insert cache stages ********/
/*!
 * \brief Create a block that reads a buffer region into a read cache. It requires:
 * 1) There is at most one block who writes the buffer in the scope.
 * 2) The scope block have stage-pipeline property.
 * \param self The state of the schedule
 * \param block_sref The consumer block of the target buffer.
 * \param read_buffer_index The index of the buffer in block's read region.
 * \param storage_scope The target storage scope.
 * \param consumer_blocks Array of blocks that consume the cache.
 * \return The cache stage block.
 */
TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
                           const String& storage_scope, const Array<StmtSRef> consumer_blocks = {});
/*!
 * \brief Create a block that writes a buffer region into a write cache. It requires:
 * 1) There is only one block that writes the target buffer.
 * 2) The scope block have stage-pipeline property.
 * \param self The state of the schedule
 * \param block_sref The producer of the buffer
 * \param write_buffer_index The index of the buffer in block's write region
 * \param storage_scope The target storage scope
 * \param consumer_blocks Array of blocks that consume the cache.
 * \return The cache stage block.
 */
TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
                            const String& storage_scope,
                            const Array<StmtSRef> consumer_blocks = {});
/*!
 * \brief Create 2 blocks that read&write a buffer region into a read/write cache.
 * It requires the target block both read & write the target buffer.
 * \param self The state of the schedule
 * \param block_sref The target block operates on the target buffer.
 * \param read_buffer_index The index of the buffer in block's read region.
 * \param storage_scope The target storage scope
 * \return The cache stage blocks, cache read block together with cache write block.
 */
TVM_DLL Array<StmtSRef> CacheInplace(ScheduleState self, const StmtSRef& block_sref,
                                     int read_buffer_index, const String& storage_scope);
/*!
 * \brief Create a block to cache precomputed index for later use.
 * if there is no index computation, keep unchanged.
 * \param self The state of the schedule
 * \param block_sref The target block
 * \param storage_scope The storage scope of cached block
 * \param cse_thresh The repeat threshold that determines a common sub expr
 * \return The cache stage block.
 */
TVM_DLL Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref,
                                   const String& storage_scope, int cse_thresh);
/*!
 * \brief Create a block that read/write a buffer region into a read/write cache with reindexing.
 * The layout of the cache will be the same as by the iterators of the block that reads/writes the
 * buffer. It requires:
 * 1) There is only one block who reads/writes the target buffer
 * 2) There is only one buffer load/store of this buffer in the block
 * \param self The state of the schedule
 * \param block_sref The block operates on the target buffer.
 * \param buffer_index The index of the buffer in block's read or write region.
 * \param buffer_index_type The type of the buffer index, kRead or kWrite.
 * \return The reindex stage block.
 */
TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
                         BufferIndexType buffer_index_type);
/******** Schedule: Compute location ********/
/*!
 * \brief Move a producer block under the specific loop, and regenerate the
 * loops induced by the block so that the buffer region produced by the producer block could
 * cover those regions consumed by its consumer blocks under the given loop. It requires:
 * 1) `block` and `loop` are under the same scope, `loop` is not the ancestor of `block`
 * 2) The scope block has stage-pipeline property
 * 3) The subtree of the scope block, where the given block is in, satisfies the compact dataflow
 *    condition. i.e. all the blocks in the scope block's subtree must be either complete block or
 *    reduction block
 * 4) The block is not an output block with regard to the scope block, i.e. the buffers written by
 *    the block are allocated under the scope block
 * 5) All the consumers of the block are under the given loop
 *
 * \param self The schedule state
 * \param block_sref The block to be moved
 * \param loop_sref The loop where the block to be moved to
 * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
 * \param index The block index of the loop body subtree blocks:
 * - `index = -1` means inserted into the last possible insertion point;
 * - `index = -2` means inserted into the first possible insertion point;
 * - Otherwise, `index` is a nonnegative number that indicates the insertion point
 */
TVM_DLL void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
                       bool preserve_unit_loops, int index = -1);
/*!
 * \brief Move a consumer block under the specific loop, and regenerate the
 * loops induced by the block so that the buffer region consumed by the consumer block could
 * cover those regions produced by its producer blocks under the given loop. It requires:
 * 1) `block` and `loop` are under the same scope, `loop` is not the ancestor of `block`
 * 2) The scope block has stage-pipeline property
 * 3) The subtree of the scope block, where the given block is in, satisfies the compact dataflow
 *    condition. i.e. all the blocks in the scope block's subtree must be either complete block or
 *    reduction block
 * 4) All the producers of the block are under the given loop
 *
 * \param self The schedule state
 * \param block_sref The block to be moved
 * \param loop_sref The loop where the block to be moved to
 * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
 * \param index The block index of the loop body subtree blocks:
 * - `index = -1` means inserted into the last possible insertion point;
 * - `index = -2` means inserted into the first possible insertion point;
 * - Otherwise, `index` is a nonnegative number that indicates the insertion point
 */
TVM_DLL void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref,
                              const StmtSRef& loop_sref, bool preserve_unit_loops, int index = -1);
/*!
 * \brief Inline a block into its consumer(s). It requires:
 * 1) The block is a complete non-root block, which only produces one buffer
 * 2) The block must not be the only leaf in the scope.
 * 3) The body of the block must be a BufferStore statement in the form of,
 *    A[i, j, k, ...] = ...
 *    where the indices of the LHS are all distinct atomic variables,
 *    and no variables other than those indexing variables are allowed in the statement.
 * \param self The state of the schedule
 * \param block_sref The sref to the block to be inlined to its consumer(s)
 */
TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref);
/*!
 * \brief Inline a block into its only producer. It requires:
 * 1) The block is a complete non-root block, which only produces and consumers one buffer
 * 2) The block must not be the only leaf in the scope.
 * 3) The only producer of the block is a read-after-write producer and a complete non-root block
 * 4) The body of the block must be a BufferStore statement in the form of,
 *    B[f(i, j, k, ...)] = g(i, j, k, A[i, j, k, ...] ...)
 *    where the indices of each `BufferLoad` on the RHS are all distinct atomic variables,
 *    and no variables other than those indexing variables are allowed in the statement.
 * \param self The state of the schedule
 * \param block_sref The sref to the block to be inlined to its producer
 */
TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref);
/******** Schedule: Reduction ********/
/*!
 * \brief Decompose a reduction block into two separate blocks.
 * a) The init block, which is translated from the init statement of the reduction block;
 * b) The update block, which is the original block without init statement.
 *
 * The init block is inserted right before the given loop.
 *
 * The schedule primitive requires:
 * 1) The input block is a reduction block.
 * 2) The input loop is the ancestor of the block.
 * 3) The input loop is not lower than all the loops related to reduce block var.
 * \param self The state of the schedule
 * \param block_sref The reduction block to be decomposed
 * \param loop_sref The loop above which the init block is inserted before.
 * \return The init block
 */
TVM_DLL StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
                                    const StmtSRef& loop_sref);
/*!
 * \brief Factor a reduction block by the specified loop
 * \details See python/tvm/tir/schedule/schedule.py
 * \param self The state of the schedule
 * \param loop_sref The loop outside block for which we want to do rfactor
 * \param factor_axis The position where the new dimension is placed in the new introduced rfactor
 *                    buffer. Suppose the original reduction block writes to buffer `B` with
 *                    ndim(B) dimensions, then `factor_axis` should be in range `[-ndim(B) - 1,
 *                    ndim(B)]`, and the negative index will be normalized to a non-negative one
 * \return The sref of the rfactor block
 */
TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis);
/******** Schedule: Block annotation ********/
/*! \brief The quad used by StorageAlign for (buffer_idx, axis, factor, offset) */
using StorageAlignTuple = Array<Integer>;
/*! \brief A list of StorageAlignTuple, used by StorageAlign */
using StorageAlignAnnotation = Array<StorageAlignTuple>;
/*!
 * \brief Set alignment requirement for specific dimension such that
 * stride[axis] == k * factor + offset for some k. This is useful to set memory layout for
 * more friendly memory access pattern. For example, we can set alignment to be factor=2,
 * offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared
 * memory.
 * \param self The state of the schedule
 * \param block_sref The producer block of the buffer
 * \param buffer_index The index of the buffer in block's write region
 * \param axis The dimension to be specified for alignment
 * \param factor The factor multiple of alignment
 * \param offset The required offset factor
 */
TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
                          int axis, int factor, int offset);
/*!
 * \brief Set the storage scope of a buffer, where the buffer is specified by the a block and a
 * write-index
 * \param self The state of the schedule
 * \param block_sref The sref of the producer block of the buffer
 * \param buffer_index The index of the buffer in block's write region
 * \param storage_scope The storage scope to be set
 */
TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
                      const String& storage_scope);
/*!
 * \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read
 * or write index
 * \param self The state of the schedule
 * \param block_sref The sref of the block that accesses the target buffer.
 * \param buffer_index The index of the buffer in block's read or write region.
 * \param buffer_index_type The type of the buffer index, kRead or kWrite.
 * \param axis_separators The axis separator of the buffer
 */
TVM_DLL void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
                              BufferIndexType buffer_index_type,
                              const Array<IntImm>& axis_separators);
451
/******** Schedule: Blockize & Tensorize ********/

/*!
 * \brief Convert the subtree rooted at a specific loop into a block.
 * \param self The state of the schedule
 * \param loop_sref The root of the subtree
 * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
 * \return The new block
 */
TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_unit_iters);

/*!
 * \brief Tensorize the computation enclosed by loop with the tensor intrinsic.
 * \param self The state of the schedule
 * \param block_or_loop_sref The block or loop to be tensorized.
 * \param intrin The tensor intrinsic.
 * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
 */
TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref,
                       const TensorIntrin& intrin, bool preserve_unit_iters);

/******** Schedule: Annotation ********/
/*!
 * \brief Annotate a block/loop with a key-value pair
 * \param self The state of the schedule
 * \param sref The block/loop sref to be annotated
 * \param ann_key The annotation key
 * \param ann_val The annotation value
 */
TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key,
                      const ObjectRef& ann_val);
/*!
 * \brief Unannotate a block/loop's annotation with key ann_key
 * \param self The state of the schedule
 * \param sref The block/loop to be unannotated
 * \param ann_key The annotation key
 */
TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key);
490
/******** Schedule: Layout transformation ********/
/*!
 * \brief Apply a transformation represented by IndexMap to buffer
 * \details The indices and the access region to the target buffer is transformed by the given
 * index_map. The index_map is also used to infer the new shape of the buffer. The buffer must be
 * either a parameter of the function, or allocated in some blocks (it cannot be a buffer
 * subregion created via match_buffer).
 * \param self The state of the schedule
 * \param block_sref The block sref that accesses the target buffer.
 * \param buffer_index The index of the buffer in block's read or write region.
 * \param buffer_index_type The type of the buffer index, kRead or kWrite.
 * \param index_map The transformation to apply.
 * \param pad_value The value to write into padding introduced by the transformation.
 */
TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
                             BufferIndexType buffer_index_type, const IndexMap& index_map,
                             const Optional<IndexMap>& pad_value);

/*!
 * \brief Apply a transformation represented by IndexMap to block
 * \details The block iters and the block body are transformed by the given index_map.
 * Outer loops corresponding to each new block iter are regenerated.
 * The index_map is required to be bijective affine since we need its inverse mapping.
 * \param self The state of the schedule
 * \param block_sref The block sref that refers to the block to be transformed
 * \param index_map The transformation to apply.
 */
TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
                                  const IndexMap& index_map);
520
/******** Schedule: Padding ********/
/*!
 * \brief Decompose a padding block into a block filling const pad values and a block
 * writing in-bound values.
 * \param self The state of the schedule
 * \param block_sref The block sref that match the padding pattern.
 * \param loop_sref The loop above which the const filling block is inserted before.
 * \return The padding value filling block sref.
 */
TVM_DLL StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref,
                                  const StmtSRef& loop_sref);

/*!
 * \brief Pad the computation of Einsum.
 * \param self The state of the schedule
 * \param block_sref The block sref that matches the Einsum pattern.
 * \param padding The padding for each block iter.
 */
TVM_DLL void PadEinsum(ScheduleState self, const StmtSRef& block_sref,
                       const Array<Integer>& padding);
540
/******** Schedule: Buffer transformation ********/
/*!
 * \brief Compute the target buffer via rolling buffering.
 * \details This primitive selects the outermost rollable axis with a positive bound overlap that
 * appears in the block's ancestor loops as `rolling axis`, fold and circularize the buffer along
 * the rolling dimension, append block predicate to avoid recomputing overlapping elements.
 * It requires:
 * 1) The buffer to be an intermediate buffer defined via `alloc_buffer`.
 * 2) The LCA of the producer and consumer of the buffer is a for loop, typically,
 *    the producer and consumer of the buffer are cascaded through compute_at.
 * 3) The access region of the buffer has at least one dimension that contains
 *    a positive bound overlap.
 * \param self The state of the schedule
 * \param block_sref The sref of the producer block of the buffer.
 * \param write_buffer_index The index of the buffer in block's write region.
 */
TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index);
/******** Schedule: Misc ********/
558
559} // namespace tir
560} // namespace tvm
561
562#endif // TVM_TIR_SCHEDULE_PRIMITIVE_H_
563