1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_TIR_SCHEDULE_PRIMITIVE_H_ |
20 | #define TVM_TIR_SCHEDULE_PRIMITIVE_H_ |
21 | |
22 | #include <tvm/support/random_engine.h> |
23 | #include <tvm/tir/schedule/state.h> |
24 | |
25 | #include <vector> |
26 | |
27 | namespace tvm { |
28 | namespace tir { |
29 | |
30 | /******** Schedule: Sampling ********/ |
31 | /*! |
32 | * \brief Sample a random integer from a given range. |
33 | * \param rand_state The pointer to schedule's random state. |
34 | * \param min_inclusive The minimum value of the range, inclusive. |
35 | * \param max_exclusive The maximum value of the range, exclusive. |
36 | * \return The random integer sampled in the given range. |
37 | */ |
38 | TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, |
39 | int32_t min_inclusive, int32_t max_exclusive); |
40 | /*! |
41 | * \brief Sample k random integers from given range without replacement, i.e., no duplication. |
42 | * \param rand_state The pointer to schedule's random state |
43 | * \param n The range is defined as 0 to n-1. |
44 | * \param k The total number of samples. |
45 | * \return The randomly selected samples from the n candidates. |
46 | */ |
47 | std::vector<int32_t> SampleWithoutReplacement( |
48 | support::LinearCongruentialEngine::TRandState* rand_state, int32_t n, int32_t k); |
49 | /*! |
50 | * \brief Sample one category from the candidates according to the probability weights. |
51 | * \param rand_state The pointer to schedule's random state |
52 | * \param candidates The candidates |
53 | * \param probs The probability distribution of the candidates |
54 | * \param decision The sampling decision, if any |
55 | * \return The random variable sampled from candidates |
56 | */ |
57 | TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, |
58 | const Array<Integer>& candidates, const Array<FloatImm>& probs, |
59 | Optional<Integer>* decision); |
60 | /*! |
61 | * \brief Create a sampling function that does multinomial sampling. |
62 | * \param rand_state The random state. |
63 | * \param weights The weights for multinomial sampling. |
64 | * \return The multinomial sampling function. |
65 | */ |
66 | TVM_DLL std::function<int32_t()> MakeMultinomialSampler( |
67 | support::LinearCongruentialEngine::TRandState* rand_state, const std::vector<double>& weights); |
68 | /*! |
69 | * \brief Sample the factors to perfect tile a specific loop |
70 | * \param rand_state The random state |
71 | * \param extent The loop extent to be tiled |
72 | * \param n_splits The number of tiles to be sampled |
73 | * \return A list of length `n_splits`, the random perfect tile sizes sampled |
74 | */ |
75 | TVM_DLL std::vector<int64_t> SamplePerfectTile( |
76 | support::LinearCongruentialEngine::TRandState* rand_state, // |
77 | int32_t extent, int32_t n_splits); |
78 | /*! |
79 | * \brief Sample the factors to perfect tile a specific loop |
80 | * \param rand_state The random state |
81 | * \param extent The loop extent to be tiled |
82 | * \param n_split The number of tiles to be sampled |
83 | * \param max_innermost_factor The maximum tile size allowed to be sampled in the innermost loop |
84 | * \return A list of length `n_split`, the random perfect tile sizes sampled |
85 | */ |
86 | TVM_DLL std::vector<int64_t> SamplePerfectTile( |
87 | support::LinearCongruentialEngine::TRandState* rand_state, // |
88 | int32_t extent, int32_t n_split, int32_t max_innermost_factor); |
89 | /*! |
90 | * \brief Sample the factors to perfect tile a specific loop |
91 | * \param rand_state The random state |
92 | * \param loop_sref The loop to be tiled |
93 | * \param n_split The number of tiles to be sampled |
94 | * \param max_innermost_factor The maximum tile size allowed to be sampled in the innermost loop |
95 | * \param decision The sampling decision |
96 | * \return A list of length `n_split`, the random perfect tile sizes sampled |
97 | */ |
98 | TVM_DLL std::vector<int64_t> SamplePerfectTile( |
99 | support::LinearCongruentialEngine::TRandState* rand_state, // |
100 | const tir::StmtSRef& loop_sref, int32_t n_split, int32_t max_innermost_factor, |
101 | Optional<Array<Integer>>* decision); |
102 | /*! |
103 | * \brief Sample a compute-at location of the given block |
104 | * \param self The schedule state |
105 | * \param rand_state The random state |
106 | * \param block_sref The sref of the block whose compute-at location is to be sampled |
107 | * \param decision The sampling decision |
108 | * \return The sampled loop where the input block is to be computed at |
109 | */ |
110 | TVM_DLL tir::StmtSRef SampleComputeLocation( |
111 | tir::ScheduleState self, support::LinearCongruentialEngine::TRandState* rand_state, |
112 | const tir::StmtSRef& block_sref, Optional<Integer>* decision); |
113 | |
114 | /******** Schedule: Get blocks & loops ********/ |
115 | /*! |
116 | * \brief Retrieves blocks in a specific function by their name |
117 | * \param self The schedule state |
118 | * \param name The name of the blocks to be retrieved |
119 | * \param gv The function in which the blocks are to be retrieved |
120 | * \return A list of blocks with the specific name |
121 | */ |
122 | Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const GlobalVar& gv); |
123 | /*! |
124 | * \brief Gets the parent loops of the block in its scope, from outer to inner |
125 | * \note Unlike most primitives, this query does not take the schedule state |
126 | * \param block_sref The query block |
127 | * \return A list of loops above the given block in its scope, from outer to inner |
128 | */ |
129 | Array<StmtSRef> GetLoops(const StmtSRef& block_sref); |
130 | /*! |
131 | * \brief Get the leaf blocks of a specific block/loop |
132 | * \param self The schedule state |
133 | * \param parent_sref The query block/loop |
134 | * \return A list of leaf blocks inside a specific block/loop |
135 | */ |
136 | Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref); |
137 | /*! |
138 | * \brief Get the producers of a specific block |
139 | * \param self The schedule state |
140 | * \param block_sref The block in the query |
141 | * \return A list of blocks, the producers of the given block |
142 | */ |
143 | Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref); |
144 | /*! |
145 | * \brief Get the consumers of a specific block |
146 | * \param self The schedule state |
147 | * \param block_sref The block in the query |
148 | * \return A list of blocks, the consumers of the given block |
149 | */ |
150 | Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref); |
151 | /******** Schedule: Transform loops ********/ |
152 | /*! |
153 | * \brief Split a loop into a list of consecutive loops. It requires: |
154 | * 1) The loop can't have annotation or thread binding. |
155 | * 2) The loop must start with 0. |
156 | * \param self The state of the schedule |
157 | * \param loop_sref The sref to the loop being split |
158 | * \param factors The splitting factors |
159 | * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings |
160 | * \return An array of srefs to the loops after splitting |
161 | */ |
162 | TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, |
163 | const Array<PrimExpr>& factors, bool preserve_unit_iters); |
164 | /*! |
165 | * \brief Fuse a list of consecutive loops into one. It requires: |
166 | * 1) The loops can't have annotations or thread bindings. |
167 | * 2) The inner loop must be the only child of the outer loop. |
168 | * 3) All loops must start with 0. |
169 | * 4) The domain of a loop to be fused cannot depend on another loop to be fused. |
170 | * \param self The state of the schedule |
171 | * \param loop_srefs An array of srefs to the loops to be fused |
172 | * \param preserve_unit_loops Whether or not to preserve unit iterators in block bindings |
173 | * \return The sref to the fused loop |
174 | */ |
175 | TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, |
176 | bool preserve_unit_loops); |
177 | /*! |
178 | * \brief Reorder a list of loops. It doesn't require the loops to be consecutive. |
179 | * It requires: |
180 | * 1) The loops are in the same chain. That means: the loops can be ordered to [l_1, l_2, ... , |
181 | * l_n] where l_i is an ancestor of l_{i+1} and there are only single-branch loops between |
182 | * l_1 and l_n (which also indicates they are under the same scope). |
183 | * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops. |
184 | * 3) For every block under the loop nests, its block binding must be affine, and the block |
185 | * variables must be either data parallel or reduction. |
186 | * 4) No duplicated loops are allowed in the arguments. |
187 | * \param self The state of the schedule |
188 | * \param ordered_loop_srefs An array of srefs which indicates the new order of loops |
189 | */ |
190 | TVM_DLL void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs); |
191 | |
192 | /*! |
193 | * \brief Create a new unit loop (a loop with extent 1) on top of the specific |
194 | * block or loop |
195 | * \param self The state of the schedule |
196 | * \param sref The block/loop above which the new unit loop is created |
197 | * \return The sref to the new unit loop |
198 | * |
199 | */ |
200 | TVM_DLL StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref); |
201 | |
202 | /******** Schedule: Manipulate ForKind ********/ |
203 | /*! |
204 | * \brief Parallelize the input loop. It requires: |
205 | * 1) The scope block that the loop is in should have stage-pipeline property |
206 | * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine |
207 | * bindings |
208 | * 3) For each block under the loop, the loop can only be contained in data-parallel block iters' |
209 | * bindings |
210 | * \param self The state of the schedule |
211 | * \param loop_sref The sref of the loop to be parallelized |
212 | */ |
213 | TVM_DLL void Parallel(ScheduleState self, const StmtSRef& loop_sref); |
214 | /*! |
215 | * \brief Vectorize the input loop. It requires: |
216 | * 1) The scope block that the loop is in should have stage-pipeline property |
217 | * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine |
218 | * bindings |
219 | * 3) For each block under the loop, the loop can only be contained in data-parallel block iters' |
220 | * bindings |
221 | * \param self The state of the schedule |
222 | * \param loop_sref The sref of the loop to be vectorized |
223 | */ |
224 | TVM_DLL void Vectorize(ScheduleState self, const StmtSRef& loop_sref); |
225 | /*! |
226 | * \brief Bind the input loop to the given thread axis. It requires: |
227 | * 1) The scope block that the loop is in should have stage-pipeline property |
228 | * 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine |
229 | * bindings |
230 | * 3) For each block under the loop, if the thread axis starts with `threadIdx`, the loop can only |
231 | * be contained in data-parallel block iter and reduction block iters' bindings. Otherwise the |
232 | * loop can only be contained in data-parallel block iters' bindings |
233 | * \param self The state of the schedule |
234 | * \param loop_sref The sref of the loop to be bound to the thread axis |
235 | * \param thread_axis The thread axis to be bound to the loop |
236 | */ |
237 | TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& thread_axis); |
238 | /*! |
239 | * \brief Unroll the input loop. It requires nothing |
240 | * \param self The state of the schedule |
241 | * \param loop_sref The loop to be unrolled |
242 | */ |
243 | TVM_DLL void Unroll(ScheduleState self, const StmtSRef& loop_sref); |
244 | /******** Schedule: Insert cache stages ********/ |
245 | /*! |
246 | * \brief Create a block that reads a buffer region into a read cache. It requires: |
247 | * 1) There is at most one block that writes the buffer in the scope. |
248 | * 2) The scope block has the stage-pipeline property. |
249 | * \param self The state of the schedule |
250 | * \param block_sref The consumer block of the target buffer. |
251 | * \param read_buffer_index The index of the buffer in block's read region. |
252 | * \param storage_scope The target storage scope. |
253 | * \param consumer_blocks Array of blocks that consume the cache. |
254 | * \return The cache stage block. |
255 | */ |
256 | TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, |
257 | const String& storage_scope, const Array<StmtSRef> consumer_blocks = {}); |
258 | /*! |
259 | * \brief Create a block that writes a buffer region into a write cache. It requires: |
260 | * 1) There is only one block that writes the target buffer. |
261 | * 2) The scope block has the stage-pipeline property. |
262 | * \param self The state of the schedule |
263 | * \param block_sref The producer of the buffer |
264 | * \param write_buffer_index The index of the buffer in block's write region |
265 | * \param storage_scope The target storage scope |
266 | * \param consumer_blocks Array of blocks that consume the cache. |
267 | * \return The cache stage block. |
268 | */ |
269 | TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, |
270 | const String& storage_scope, |
271 | const Array<StmtSRef> consumer_blocks = {}); |
272 | /*! |
273 | * |
274 | * \brief Create 2 blocks that read & write a buffer region into a read/write cache. |
275 | * It requires that the target block both reads & writes the target buffer. |
276 | * \param self The state of the schedule |
277 | * \param block_sref The target block operates on the target buffer. |
278 | * \param read_buffer_index The index of the buffer in block's read region. |
279 | * \param storage_scope The target storage scope |
280 | * \return The cache stage blocks, cache read block together with cache write block. |
281 | */ |
282 | TVM_DLL Array<StmtSRef> CacheInplace(ScheduleState self, const StmtSRef& block_sref, |
283 | int read_buffer_index, const String& storage_scope); |
284 | /*! |
285 | * \brief Create a block to cache precomputed index for later use. |
286 | * if there is no index computation, keep unchanged. |
287 | * \param block_sref The target block |
288 | * \param storage_scope The storage scope of cached block |
289 | * \param cse_thresh The repeat threshold that determines a common sub expr |
290 | * \return The cache stage block. |
291 | */ |
292 | TVM_DLL Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref, |
293 | const String& storage_scope, int cse_thresh); |
294 | /*! |
295 | * |
296 | * \brief Create a block that reads/writes a buffer region into a read/write cache with reindexing. |
297 | * The layout of the cache is determined by the iterators of the block that reads/writes the |
298 | * buffer. It requires: |
299 | * 1) There is only one block that reads/writes the target buffer |
300 | * 2) There is only one buffer load/store of this buffer in the block |
301 | * \param self The state of the schedule |
302 | * \param block_sref The block operates on the target buffer. |
303 | * \param buffer_index The index of the buffer in block's read or write region. |
304 | * \param buffer_index_type The type of the buffer index, kRead or kWrite. |
305 | * \return The reindex stage block. |
306 | */ |
307 | TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index, |
308 | BufferIndexType buffer_index_type); |
309 | /******** Schedule: Compute location ********/ |
310 | /*! |
311 | * \brief Move a producer block under the specific loop, and regenerate the |
312 | * loops induced by the block so that the buffer region produced by the producer block could |
313 | * cover those regions consumed by its consumer blocks under the given loop. It requires: |
314 | * 1) `block` and `loop` are under the same scope, `loop` is not the ancestor of `block` |
315 | * 2) The scope block has stage-pipeline property |
316 | * 3) The subtree of the scope block, where the given block is in, satisfies the compact dataflow |
317 | * condition. i.e. all the blocks in the scope block's subtree must be either complete block or |
318 | * reduction block |
319 | * 4) The block is not an output block with regard to the scope block, i.e. the buffers written by |
320 | * the block are allocated under the scope block |
321 | * 5) All the consumers of the block are under the given loop |
322 | * |
323 | * \param self The schedule state |
324 | * \param block_sref The block to be moved |
325 | * \param loop_sref The loop where the block is to be moved to |
325 | * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 |
326 | * \param index The block index of the loop body subtree blocks: |
327 | * - `index = -1` means inserted into the last possible insertion point; |
328 | * - `index = -2` means inserted into the first possible insertion point; |
329 | * - Otherwise, `index` is a nonnegative number that indicates the insertion point |
330 | */ |
331 | TVM_DLL void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, |
332 | bool preserve_unit_loops, int index = -1); |
333 | /*! |
334 | * \brief Move a consumer block under the specific loop, and regenerate the |
335 | * loops induced by the block so that the buffer region consumed by the consumer block could |
336 | * cover those regions produced by its producer blocks under the given loop. It requires: |
337 | * 1) `block` and `loop` are under the same scope, `loop` is not the ancestor of `block` |
338 | * 2) The scope block has stage-pipeline property |
339 | * 3) The subtree of the scope block, where the given block is in, satisfies the compact dataflow |
340 | * condition. i.e. all the blocks in the scope block's subtree must be either complete block or |
341 | * reduction block |
342 | * 4) All the producers of the block are under the given loop |
343 | * |
344 | * \param self The schedule state |
345 | * \param block_sref The block to be moved |
346 | * \param loop_sref The loop where the block is to be moved to |
347 | * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 |
348 | * \param index The block index of the loop body subtree blocks: |
349 | * - `index = -1` means inserted into the last possible insertion point; |
350 | * - `index = -2` means inserted into the first possible insertion point; |
351 | * - Otherwise, `index` is a nonnegative number that indicates the insertion point |
352 | */ |
353 | TVM_DLL void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, |
354 | const StmtSRef& loop_sref, bool preserve_unit_loops, int index = -1); |
355 | /*! |
356 | * \brief Inline a block into its consumer(s). It requires: |
357 | * 1) The block is a complete non-root block, which only produces one buffer |
358 | * 2) The block must not be the only leaf in the scope. |
359 | * 3) The body of the block must be a BufferStore statement in the form of, |
360 | * A[i, j, k, ...] = ... |
361 | * where the indices of the LHS are all distinct atomic variables, |
362 | * and no variables other than those indexing variables are allowed in the statement. |
363 | * \param self The state of the schedule |
364 | * \param block_sref The sref to the block to be inlined to its consumer(s) |
365 | */ |
366 | TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref); |
367 | /*! |
368 | * \brief Inline a block into its only producer. It requires: |
369 | * 1) The block is a complete non-root block, which only produces and consumes one buffer |
370 | * 2) The block must not be the only leaf in the scope. |
371 | * 3) The only producer of the block is a read-after-write producer and a complete non-root block |
372 | * 4) The body of the block must be a BufferStore statement in the form of, |
373 | * B[f(i, j, k, ...)] = g(i, j, k, A[i, j, k, ...] ...) |
374 | * where the indices of each `BufferLoad` on the RHS are all distinct atomic variables, |
375 | * and no variables other than those indexing variables are allowed in the statement. |
376 | * \param self The state of the schedule |
377 | * \param block_sref The sref to the block to be inlined to its producer |
378 | */ |
379 | TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref); |
380 | /******** Schedule: Reduction ********/ |
381 | /*! |
382 | * \brief Decompose a reduction block into two separate blocks. |
383 | * a) The init block, which is translated from the init statement of the reduction block; |
384 | * b) The update block, which is the original block without init statement. |
385 | * |
386 | * The init block is inserted right before the given loop. |
387 | * |
388 | * The schedule primitive requires: |
389 | * 1) The input block is a reduction block. |
390 | * 2) The input loop is an ancestor of the block. |
391 | * 3) The input loop is not lower than all the loops related to reduce block var. |
392 | * \param self The state of the schedule |
392 | * \param block_sref The reduction block to be decomposed |
393 | * \param loop_sref The loop above which the init block is inserted |
394 | * \return The init block |
395 | */ |
396 | TVM_DLL StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, |
397 | const StmtSRef& loop_sref); |
398 | /*! |
399 | * \brief Factor a reduction block by the specified loop |
400 | * \details See python/tvm/tir/schedule/schedule.py |
401 | * \param self The state of the schedule |
402 | * \param loop_sref The loop outside block for which we want to do rfactor |
403 | * \param factor_axis The position where the new dimension is placed in the new introduced rfactor |
404 | * buffer. Suppose the original reduction block writes to buffer `B` with |
405 | * ndim(B) dimensions, then `factor_axis` should be in range `[-ndim(B) - 1, |
406 | * ndim(B)]`, and the negative index will be normalized to a non-negative one |
407 | * \return The sref of the rfactor block |
408 | */ |
409 | TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis); |
410 | /******** Schedule: Block annotation ********/ |
411 | /*! \brief The quad used by StorageAlign for (buffer_idx, axis, factor, offset) */ |
412 | using StorageAlignTuple = Array<Integer>; |
413 | /*! \brief A list of StorageAlignTuple, used by StorageAlign */ |
414 | using StorageAlignAnnotation = Array<StorageAlignTuple>; |
415 | /*! |
416 | * \brief Set alignment requirement for specific dimension such that |
417 | * stride[axis] == k * factor + offset for some k. This is useful to set memory layout for |
418 | * more friendly memory access pattern. For example, we can set alignment to be factor=2, |
419 | * offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared |
420 | * memory. |
421 | * \param self The state of the schedule |
422 | * \param block_sref The producer block of the buffer |
423 | * \param buffer_index The index of the buffer in block's write region |
424 | * \param axis The dimension to be specified for alignment |
425 | * \param factor The factor multiple of alignment |
426 | * \param offset The required offset factor |
427 | */ |
428 | TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, |
429 | int axis, int factor, int offset); |
430 | /*! |
431 | * \brief Set the storage scope of a buffer, where the buffer is specified by a block and a |
432 | * write-index |
433 | * \param self The state of the schedule |
434 | * \param block_sref The sref of the producer block of the buffer |
435 | * \param buffer_index The index of the buffer in block's write region |
436 | * \param storage_scope The storage scope to be set |
437 | */ |
438 | TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, |
439 | const String& storage_scope); |
440 | /*! |
441 | * \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read |
442 | * or write index |
443 | * \param block_sref The block that accesses the target buffer. |
444 | * \param buffer_index The index of the buffer in block's read or write region. |
445 | * \param buffer_index_type The type of the buffer index, kRead or kWrite. |
446 | * \param axis_separators The axis separator of the buffer |
447 | */ |
448 | TVM_DLL void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer_index, |
449 | BufferIndexType buffer_index_type, |
450 | const Array<IntImm>& axis_separators); |
451 | |
452 | /******** Schedule: Blockize & Tensorize ********/ |
453 | |
454 | /*! |
455 | * \brief Convert the subtree rooted at a specific loop into a block. |
456 | * \param self The state of the schedule |
457 | * \param loop_sref The root of the subtree |
458 | * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings |
459 | * \return The new block |
460 | */ |
461 | TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_unit_iters); |
462 | |
463 | /*! |
464 | * \brief Tensorize the computation enclosed by loop with the tensor intrinsic. |
465 | * \param self The state of the schedule |
466 | * \param block_or_loop_sref The block or loop to be tensorized. |
467 | * \param intrin The tensor intrinsic. |
468 | * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings |
469 | */ |
470 | TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, |
471 | const TensorIntrin& intrin, bool preserve_unit_iters); |
472 | |
473 | /******** Schedule: Annotation ********/ |
474 | /*! |
475 | * \brief Annotate a block/loop with a key value pair |
476 | * \param self The state of the schedule |
477 | * \param sref The block/loop sref to be annotated |
478 | * \param ann_key The annotation key |
479 | * \param ann_val The annotation value |
480 | */ |
481 | TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key, |
482 | const ObjectRef& ann_val); |
483 | /*! |
484 | * \brief Unannotate a block/loop's annotation with key ann_key |
485 | * \param self The state of the schedule |
486 | * \param sref The block/loop to be unannotated |
487 | * \param ann_key The annotation key |
488 | */ |
489 | TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key); |
490 | |
491 | /******** Schedule: Layout transformation ********/ |
492 | /*! |
493 | * \brief Apply a transformation represented by IndexMap to buffer |
494 | * \details The indices and the access region to the target buffer is transformed by the given |
495 | * index_map. The index_map is also used to infer the new shape of the buffer. Buffer must be |
496 | * one of the parameter of the function, or allocated in some blocks (it cannot be a buffer |
497 | * subregion created via match_buffer). |
498 | * \param self The state of the schedule |
499 | * \param block_sref The block sref that accesses the target buffer. |
500 | * \param buffer_index The index of the buffer in block's read or write region. |
501 | * \param buffer_index_type The type of the buffer index, kRead or kWrite. |
502 | * \param index_map The transformation to apply. |
503 | * \param pad_value The value to write into padding introduced by the transformation. |
504 | */ |
505 | TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, |
506 | BufferIndexType buffer_index_type, const IndexMap& index_map, |
507 | const Optional<IndexMap>& pad_value); |
508 | |
509 | /*! |
510 | * \brief Apply a transformation represented by IndexMap to block |
511 | * \details The block iters and the block body are transformed by the given index_map. |
512 | * Outer loops corresponding to each new block iter are regenerated. |
513 | * The index_map is required to be bijective affine since we need its inverse mapping. |
514 | * \param self The state of the schedule |
515 | * \param block_sref The block sref that refers to the block to be transformed |
516 | * \param index_map The transformation to apply. |
517 | */ |
518 | TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, |
519 | const IndexMap& index_map); |
520 | |
521 | /******** Schedule: Padding ********/ |
522 | /*! |
523 | * \brief Decompose a padding block into a block filling const pad values and a block |
524 | * writing in-bound values. |
525 | * \param block_sref The block sref that match the padding pattern. |
526 | * \param loop_sref The loop above which the const filling block is inserted |
527 | * \return The padding value filling block sref. |
528 | */ |
529 | TVM_DLL StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref, |
530 | const StmtSRef& loop_sref); |
531 | |
532 | /*! |
533 | * \brief Pad the computation of Einsum. |
534 | * \param self The state of the schedule |
535 | * \param block_sref The block sref that matches the Einsum pattern. |
536 | * \param padding The padding for each block iter. |
537 | */ |
538 | TVM_DLL void PadEinsum(ScheduleState self, const StmtSRef& block_sref, |
539 | const Array<Integer>& padding); |
540 | |
541 | /******** Schedule: Buffer transformation ********/ |
542 | /*! |
543 | * \brief Compute the target buffer via rolling buffering. |
544 | * \details This primitive selects the outermost rollable axis with a positive bound overlap that |
545 | * appears in the block's ancestor loops as `rolling axis`, fold and circularize the buffer along |
546 | * the rolling dimension, append block predicate to avoid recomputing overlapping elements. |
547 | * It requires: |
548 | * 1) The buffer to be an intermediate buffer defined via `alloc_buffer`. |
549 | * 2) The LCA of the producer and consumer of the buffer is a for loop, typically, |
550 | * the producer and consumer of the buffer are cascaded through compute_at. |
551 | * 3) The access region of the buffer has at least one dimension that contains |
552 | * a positive bound overlap. |
553 | * \param block_sref The producer block of the buffer. |
554 | * \param write_buffer_index The index of the buffer in block's write region. |
555 | */ |
556 | TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index); |
557 | /******** Schedule: Misc ********/ |
558 | |
559 | } // namespace tir |
560 | } // namespace tvm |
561 | |
562 | #endif // TVM_TIR_SCHEDULE_PRIMITIVE_H_ |
563 | |