/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/tir/transform.h
 * \brief TIR specific transformation passes.
 */
#ifndef TVM_TIR_TRANSFORM_H_
#define TVM_TIR_TRANSFORM_H_

#include <tvm/ir/transform.h>
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>

#include <string>
#include <vector>

namespace tvm {
namespace tir {
namespace transform {

using tvm::transform::Pass;
using tvm::transform::PassContext;
using tvm::transform::PassContextNode;
using tvm::transform::PassInfo;
using tvm::transform::PassInfoNode;
using tvm::transform::PassNode;
using tvm::transform::Sequential;

/*!
 * \brief Create a function pass that optimizes PrimFuncs.
 *
 * \param pass_func The packed function that contains the optimization.
 * \param opt_level The optimization level of the function pass.
 * \param name The name of the function pass.
 * \param required The list of the passes that the function pass is dependent on.
 *
 * \return The created function pass.
 */
TVM_DLL Pass CreatePrimFuncPass(
    const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
    int opt_level, String name, tvm::Array<String> required);
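
/*
 * Example usage of CreatePrimFuncPass (illustrative sketch, not part of the
 * TVM API surface): a pass that returns every PrimFunc unchanged. The pass
 * name "tir.MyIdentityPass" and the empty dependency list are assumptions
 * made only for this example.
 *
 * \code{.cpp}
 *   tvm::transform::Pass MyIdentityPass() {
 *     auto pass_func = [](tvm::tir::PrimFunc f, tvm::IRModule m,
 *                         tvm::transform::PassContext ctx) -> tvm::tir::PrimFunc {
 *       return f;  // inspect or rewrite `f` here
 *     };
 *     return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.MyIdentityPass", {});
 *   }
 * \endcode
 */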

/*!
 * \brief Inject prefetch instructions into stmt.
 *
 * \return The pass.
 */
TVM_DLL Pass InjectPrefetch();

// TODO(tvm-team): consolidate configs to the PassContext
/*!
 * \brief Flatten the multi-dimensional read/write
 * to single dimensional Load/Store
 *
 * \param cache_line_size The size of CPU cache line.
 * \param create_bound_attribute Whether to create bound attributes.
 *
 * \return The Pass
 */
TVM_DLL Pass StorageFlatten(int cache_line_size, bool create_bound_attribute = false);
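
/*
 * Example usage of StorageFlatten (illustrative sketch): applying the pass to
 * an IRModule. The module `mod` and the cache line size of 64 bytes are
 * assumptions made only for this example.
 *
 * \code{.cpp}
 *   tvm::tir::transform::Pass flatten = tvm::tir::transform::StorageFlatten(64);
 *   mod = flatten(mod);  // runs under the current PassContext
 * \endcode
 */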

/*!
 * \brief Inject copy intrinsics with optional pad.
 *
 * \param pragma_key The pragma key for hint of copy.
 * \param fintrin The function with signature
 *
 *   Stmt fintrin(Buffer src,
 *                Buffer dst,
 *                Array<Expr> pad_before,
 *                Array<Expr> pad_after,
 *                Expr pad_value)
 * \return The pass.
 */
TVM_DLL Pass InjectCopyIntrin(String pragma_key, runtime::PackedFunc fintrin);
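
/*
 * Example of an fintrin callback for InjectCopyIntrin (illustrative sketch):
 * the pragma key "my_copy_pragma" and the no-op body are assumptions made only
 * for this example; a real callback would emit the device-specific copy stmt.
 *
 * \code{.cpp}
 *   tvm::runtime::TypedPackedFunc<tvm::tir::Stmt(
 *       tvm::tir::Buffer, tvm::tir::Buffer, tvm::Array<tvm::PrimExpr>,
 *       tvm::Array<tvm::PrimExpr>, tvm::PrimExpr)>
 *       fintrin([](tvm::tir::Buffer src, tvm::tir::Buffer dst,
 *                  tvm::Array<tvm::PrimExpr> pad_before,
 *                  tvm::Array<tvm::PrimExpr> pad_after,
 *                  tvm::PrimExpr pad_value) -> tvm::tir::Stmt {
 *         return tvm::tir::Evaluate(0);  // placeholder no-op statement
 *       });
 *   tvm::transform::Pass inject_copy =
 *       tvm::tir::transform::InjectCopyIntrin("my_copy_pragma", fintrin);
 * \endcode
 */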

/*!
 * \brief Detect and insert sync points to the co-processor.
 *
 * \return The pass.
 */
TVM_DLL Pass CoProcSync();

/*!
 * \brief Lift common attrs with attr_key to outer scope.
 *
 * \param attr_key The attribute key to be checked.
 * \return The pass.
 */
TVM_DLL Pass LiftAttrScope(String attr_key);

/*!
 * \brief Partition loops in the stmt.
 *
 * \return The pass.
 */
TVM_DLL Pass LoopPartition();

/*!
 * \brief Lower vectorization loops.
 *
 * \param enable_vectorize Whether vectorization is enabled.
 *
 * \return The pass.
 */
TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true);

/*!
 * \brief Inject virtual thread loops.
 *
 * \return The pass.
 */
TVM_DLL Pass InjectVirtualThread();

/*!
 * \brief Inject double buffer statements.
 *
 * \return The pass.
 */
TVM_DLL Pass InjectDoubleBuffer();

/*!
 * \brief Rewrite storage allocation pattern.
 * Moves the allocations to the outermost possible scope and
 * tries to share space between allocations to make
 * a static allocation plan when possible.
 *
 * \return The pass.
 */
TVM_DLL Pass StorageRewrite();

/*!
 * \brief Unroll the constant loops marked by unroll.
 * This pass also automatically attaches a pragma unroll tag to loops that meet the criteria.
 *
 * \return The pass.
 */
TVM_DLL Pass UnrollLoop();

/*!
 * \brief Remove No Op from the Stmt.
 *
 * \return The pass.
 */
TVM_DLL Pass RemoveNoOp();

/*!
 * \brief Detect and rewrite unsafe select that contains memory access.
 *
 * \return The pass.
 */
TVM_DLL Pass RewriteUnsafeSelect();

/*!
 * \brief Run arithmetic simplifications on the statements and expressions.
 *
 * \return The pass.
 */
TVM_DLL Pass Simplify();

/*!
 * \brief Instruments bound checkers.
 *
 * \return The pass.
 */
TVM_DLL Pass InstrumentBoundCheckers();

/*!
 * \brief Transform the high-level PrimFunc to a low-level version
 * that can be used as an API function.
 *
 * The main task of this function is to create code to:
 * - Map the values in api_args to the Vars that are required by the body.
 * - Insert assertions to check the type/value of the passed arguments.
 *
 * \note
 * The function signature has two cases:
 *
 *   let num_packed_args = len(api_args);
 *
 *   if num_packed_args is zero:
 *      f()
 *
 *   if num_packed_args is not zero:
 *      f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
 *        api_arg_k, api_arg_k+1, ... api_arg_n,
 *        TVMValue* out_ret_val, int* out_ret_tcode)
 *
 *   where n == len(api_args), k == num_packed_args
 *
 * \return The pass.
 */
TVM_DLL Pass MakePackedAPI();

/*!
 * \brief Transform the high-level PrimFunc to a C signature that can be used
 * to call the operator directly.
 *
 * The main task of this function is to create code that maps the values in
 * api_args to the Vars that are required by the body.
 *
 * \return The pass.
 */
TVM_DLL Pass MakeUnpackedAPI();

/*!
 * \brief Remap the thread axis.
 *
 * This can be used to get an equivalent program that uses
 * threadIdx.y in place of threadIdx.x by passing
 * {"threadIdx.x": thread_axis("threadIdx.y")}.
 *
 * \param axis_map The mapping from the original thread axis name to the IterVar it is remapped to.
 * \return The pass.
 */
TVM_DLL Pass RemapThreadAxis(Map<String, IterVar> axis_map);
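
/*
 * Example of building an axis_map for RemapThreadAxis (illustrative sketch):
 * the extent 16, the variable name "ty", and the module `mod` are assumptions
 * made only for this example.
 *
 * \code{.cpp}
 *   tvm::Map<tvm::String, tvm::tir::IterVar> axis_map = {
 *       {"threadIdx.x",
 *        tvm::tir::IterVar(tvm::Range(0, 16), tvm::tir::Var("ty"),
 *                          tvm::tir::kThreadIndex, "threadIdx.y")}};
 *   mod = tvm::tir::transform::RemapThreadAxis(axis_map)(mod);
 * \endcode
 */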

/*!
 * \brief Lower custom datatypes.
 *
 * See tvm::datatypes::Registry for more information on adding custom datatypes.
 *
 * \return The pass.
 */
TVM_DLL Pass LowerCustomDatatypes();

/*!
 * \brief Decorate the function's body as a device function.
 *
 * \return The pass.
 */
TVM_DLL Pass DecorateDeviceScope();

/*!
 * \brief Split the function into a host function and device functions.
 *
 * \return The pass.
 */
TVM_DLL Pass SplitHostDevice();

/*!
 * \brief Skip assert stmts.
 *
 * \return The pass.
 */
TVM_DLL Pass SkipAssert();

/*!
 * \brief Insert sync between parallel read/write of shared buffers.
 *
 * \param storage_scope The storage scope considered.
 * \return The pass.
 */
TVM_DLL Pass ThreadSync(String storage_scope);
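
/*
 * Example usage of ThreadSync (illustrative sketch): inserting barriers for
 * GPU shared and warp memory. The module `mod` is an assumed IRModule whose
 * PrimFuncs already carry GPU thread bindings.
 *
 * \code{.cpp}
 *   tvm::transform::Sequential device_sync({
 *       tvm::tir::transform::ThreadSync("shared"),
 *       tvm::tir::transform::ThreadSync("warp"),
 *   });
 *   mod = device_sync(mod);
 * \endcode
 */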

/*!
 * \brief Lower cross thread allreduce.
 *
 * \return The pass.
 */
TVM_DLL Pass LowerThreadAllreduce();

/*!
 * \brief Infer the TensorCore fragment information using tensor intrinsics.
 *
 * \return The pass.
 */
TVM_DLL Pass InferFragment();

/*!
 * \brief This annotation is for nodes to be disabled for builtin lowering
 */
static constexpr const char* kDisableLowerTVMBuiltin = "disable_lower_builtin";

/*!
 * \brief Lower builtin intrinsics.
 * \return The pass.
 */
TVM_DLL Pass LowerTVMBuiltin();

/*!
 * \brief Lower the target-specific function intrinsics in each of the functions.
 *
 * \return The pass.
 */
TVM_DLL Pass LowerIntrin();

/*!
 * \brief Lower warp memory access to low-level device related function calls.
 * \return The pass.
 */
TVM_DLL Pass LowerWarpMemory();

/*!
 * \brief Lower attached storage access information on device.
 *
 * \note Run this pass after all storage access analysis finishes.
 *
 * \return The pass.
 */
TVM_DLL Pass LowerDeviceStorageAccessInfo();

/*!
 * \brief Combine context calls in the host function.
 *
 * \return The pass.
 */
TVM_DLL Pass CombineContextCall();

/*!
 * \brief Narrow down PrimExpr datatype in stmt to target_bits.
 *
 * \param target_bits The target bits.
 *
 * \note Run this pass after StorageFlatten.
 * \return The pass.
 */
TVM_DLL Pass NarrowDataType(int target_bits);
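
/*
 * Example ordering for NarrowDataType (illustrative sketch): the note above
 * requires flattening first. The module `mod`, the 64-byte cache line, and the
 * 32-bit target width are assumptions made only for this example.
 *
 * \code{.cpp}
 *   tvm::transform::Sequential narrow_indices({
 *       tvm::tir::transform::StorageFlatten(64),
 *       tvm::tir::transform::NarrowDataType(32),
 *   });
 *   mod = narrow_indices(mod);
 * \endcode
 */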

/*!
 * \brief Legalize bf16 typed Ops. Add a cast to fp32
 * before Ops, then add a cast back to bf16.
 * \return The pass.
 */
TVM_DLL Pass BF16Legalize();

/*!
 * \brief Rewrite the pointer content type of arguments,
 * as well as Alloc internal to the function to use
 * the most frequently accessed type for load/store
 * to avoid pointer casting in backend when possible.
 *
 * \return The pass.
 */
TVM_DLL Pass PointerValueTypeRewrite();

/*!
 * \brief Hoist loop-invariant IfThenElse nodes
 * outside the eligible loops.
 *
 * \return The pass.
 */
TVM_DLL Pass HoistIfThenElse();

/*!
 * \brief Hoist loop-invariant expression nodes
 * outside the eligible loops.
 *
 * Can hoist conditionals used in IfThenElse statements and
 * expressions, bindings of variables in Let statements and
 * expressions, or boolean expressions, configurable to enable/disable
 * each hoistable type.
 *
 * \return The pass.
 */
TVM_DLL Pass HoistExpression();

/*!
 * \brief Lower cross-thread reduction from thread
 * bindings to intrinsic function calls.
 * \return The pass.
 */
TVM_DLL Pass LowerCrossThreadReduction();

/*!
 * \brief Lower block init stmt into IfThenElse stmts.
 * \return The pass.
 */
TVM_DLL Pass LowerInitBlock();

/*!
 * \brief Locate the buffer allocation at the exact position (usually the LCA
 * of the buffer accesses). This pass will inject opaque blocks
 * with alloc_buffers at the allocation site.
 * \return The pass.
 */
TVM_DLL Pass PlanAndUpdateBufferAllocationLocation();

/*!
 * \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the
 * corresponding iter_values in BlockRealize, and make the blocks opaque by removing all
 * the iter_values in BlockRealize and iter_vars in Block.
 * \return The pass.
 */
TVM_DLL Pass ConvertBlocksToOpaque();

/*!
 * \brief Compact the buffer access region by removing the buffer regions that are not accessed,
 * i.e. narrowing the buffer shape and adjusting the access region if necessary.
 *
 * Before narrowing, `B` is a `[16, 16]` buffer, but only a skinny vector `B[i, 0:16]` is accessed.
 *
 * \code
 *
 *  for i in range(0, 16):
 *      with T.block():
 *          B = T.alloc_buffer(16, 16)
 *          for j in range(0, 16):
 *              B[i, j] = A[i, j] + 1
 *          for j in range(0, 16):
 *              C[i, j] = B[i, j] + 1
 *
 * \endcode
 *
 * This pass narrows the buffer shape and adjusts its accessed region accordingly.
 * In this particular case, because only a `1 * 16` vector of `B` is accessed,
 * the pass narrows `B` to shape `[1, 16]`, and changes the access to `B[i, j]` to `B[0, j]`.
 *
 * \code
 *
 *  for i in range(0, 16):
 *      with T.block():
 *          B = T.alloc_buffer(1, 16)
 *          for j in range(0, 16):
 *              B[0, j] = A[i, j] + 1
 *          for j in range(0, 16):
 *              C[i, j] = B[0, j] + 1
 *
 * \endcode
 *
 * \return The pass.
 */
TVM_DLL Pass CompactBufferAllocation();

/*!
 * \brief Legalize packed calls by wrapping their arguments into TVMValues.
 * \return The pass.
 */
TVM_DLL Pass LegalizePackedCalls();

/*!
 * \brief Remove match buffers inside the block and validate the binding.
 * \return The pass.
 */
TVM_DLL Pass LowerMatchBuffer();

/*!
 * \brief Remove the block to ensure that the TIR cannot be scheduled again.
 * \return The pass.
 */
TVM_DLL Pass LowerOpaqueBlock();

/*!
 * \brief Flatten the multi-dimensional BufferLoad and BufferStore to single-dimensional
 * BufferLoad/BufferStore for TIR that does not contain opaque blocks.
 * \return The pass.
 */
TVM_DLL Pass FlattenBuffer();

/*!
 * \brief Flatten the multi-dimensional read/write
 * to two-dimensional texture Load/Store and realize
 * texture buffer allocations.
 *
 * \return The Pass
 */
TVM_DLL Pass TextureFlatten();

/*!
 * \brief Lower VTCM allocations.
 *
 * \return The Pass
 */
TVM_DLL Pass LowerVtcmAlloc();

/*!
 * \brief Lower Async TIR primitives to DMA copy and wait builtins.
 * \return The pass.
 */
TVM_DLL Pass LowerAsyncDMA();

/*!
 * \brief Implements a Common Subexpression Elimination (CSE) for TIR
 * which introduces let-in bindings for duplicated sub-expressions.
 * \param enable_cse_tir Whether common subexpression elimination is enabled.
 * \param identify_equiv_terms Whether equivalent terms should be identified.
 * \return The pass.
 */
TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false);
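
/*
 * Example usage of CommonSubexprElimTIR (illustrative sketch): enabling both
 * options documented above. The module `mod` is an assumed IRModule.
 *
 * \code{.cpp}
 *   // Introduce let-bindings for duplicated sub-expressions and also treat
 *   // equivalent terms as duplicates.
 *   tvm::transform::Pass cse = tvm::tir::transform::CommonSubexprElimTIR(true, true);
 *   mod = cse(mod);
 * \endcode
 */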

/*!
 * \brief Add TIR-printer output as debug information to all ops in the module.
 * \return The pass.
 */
TVM_DLL Pass InstallDebugSpans();

/*!
 * \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and
 * "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g.,
 * "threadIdx.x") use different IterVars and variables in their AttrStmts. After the
 * unification, we use a consolidated IterVar and a variable for them.
 * \return The pass.
 * \note `vthread` is a legacy behavior that will be deprecated, though thread bindings of `vthread`
 * are still also unified in this pass. Please use `vthread.x`, `vthread.y` and `vthread.z`
 * instead.
 */
TVM_DLL Pass UnifyThreadBinding();

/*!
 * \brief A pass to merge multiple TIR-level dynamic shared memory allocations into one.
 * \return The pass.
 */
TVM_DLL Pass MergeDynamicSharedMemoryAllocations();

/*!
 * \brief A post-scheduling pass that converts all parallel For loops
 * to serial ones. This is run to reduce memory consumption and/or when
 * the executor/backend does not support parallel launch of For loops.
 * \return The pass.
 */
TVM_DLL Pass ConvertForLoopsToSerial();

/*!
 * \brief The unified static memory planner pass that plans memory
 * intra- and inter-PrimFunc together. The pass requires all the
 * functions to be PrimFuncs, including the main one.
 * \return The pass.
 */
TVM_DLL Pass UnifiedStaticMemoryPlanner();

/*!
 * \brief This pass transforms annotated loops into pipelined ones where producers and consumers
 * are overlapped with the information provided in loop annotations, which enables optimization
 * techniques like prefetching and pipeline parallelism.
 *
 * The pipeline scope consists of the direct children of the annotated loop (ignoring BlockRealize,
 * Block, SeqStmt), and the number of children is denoted by `n` in the documentation.
 *
 * The following annotations are used to guide the loop transformation:
 *
 * 1) Loop annotation `software_pipeline_stage` defines the pipeline stage.
 * An array of `n` integers, and each element should be in range [0, max_stage],
 * where max_stage is the maximum (inclusive) stage.
 * 2) Loop annotation `software_pipeline_order` defines the pipeline order.
 * An array of `n` integers, a permutation of [0, 1, ..., n - 1];
 * 3) Block annotation `double_buffer_scope` controls certain buffer sizes to allow decoupling of
 * read/write dependency. It's an integer index of the write regions of the block.
 *
 * Every annotated loop is transformed into a loop with three blocks as its direct children:
 *
 * 1) Prologue block, where components whose stage is less than `max_stage` are executed;
 *
 * 2) Body block, where all the components are executed;
 *
 * 3) Epilogue block, where only components whose stage is greater than 0 will be executed.
 * The execution order is controlled by the annotation `software_pipeline_order`,
 * and thus could be different from the original order.
 *
 * Note: For nested software pipelines, the inner software pipeline will be generated first,
 * which may affect the number of the direct children of the outer loop.
 * In this case, the annotations for the outer software
 * pipeline should include the result of the inner software pipeline,
 * which is the three blocks as discussed above.
 * Example:
 *
 * Before this pass, the TIR is:
 *
 * \code{.py}
 * @T.prim_func
 * def before_transform(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None:
 *     for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
 *         for i in T.serial(0, 16,
 *                           annotations={"software_pipeline_stage": [0, 1],
 *                                        "software_pipeline_order": [0, 1]}
 *                          ):
 *             with T.block():
 *                 T.reads(A[tx, i])
 *                 T.writes(C[tx, i])
 *                 B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
 *                 with T.block("B"):
 *                     T.reads(A[tx, i])
 *                     T.writes(B[tx, 0])
 *                     B[tx, 0] = A[tx, i] * T.float32(2)
 *                 with T.block("C"):
 *                     T.reads(B[tx, 0])
 *                     T.writes(C[tx, i])
 *                     C[tx, i] = B[tx, 0] + T.float32(1)
 * \endcode
 *
 * The TIR above annotates the loop as a two-stage pipeline with no reordering.
 * After applying this pass, the TIR is transformed into:
 *
 * \code{.py}
 * @T.prim_func
 * def after_transform(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None:
 *     for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
 *         with T.block():
 *             T.reads([A[tx, 0:16]])
 *             T.writes([C[tx, 0:16]])
 *             B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
 *             with T.block("prologue"):
 *                 T.reads([A[tx, 0]])
 *                 T.writes([B[0, tx, 0]])
 *                 B[0, tx, 0] = A[tx, 0] * T.float32(2)
 *             with T.block("body"):
 *                 T.reads([A[tx, 1:16], B[0:2, tx, 0]])
 *                 T.writes([B[0:2, tx, 0], C[tx, 0:15]])
 *                 for i in T.serial(0, 15):
 *                     with T.block("B"):
 *                         T.reads([A[tx, i + 1]])
 *                         T.writes([B[(i + 1) % 2, tx, 0]])
 *                         B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
 *                     with T.block("C"):
 *                         T.reads([B[i % 2, tx, 0]])
 *                         T.writes([C[tx, i]])
 *                         C[tx, i] = B[i % 2, tx, 0] + T.float32(1)
 *             with T.block("epilogue"):
 *                 T.reads([B[1, tx, 0]])
 *                 T.writes([C[tx, 15]])
 *                 C[tx, 15] = B[1, tx, 0] + T.float32(1)
 * \endcode
 *
 * The original loop has two blocks, B and C, as its direct children. The loop annotations indicate
 * that block B has stage == 0, order == 0, while block C has stage == 1, order == 1. Therefore,
 * block B should be executed in advance of block C by one iteration. The orders 0 and 1 specify
 * the order of blocks B and C inside the body block of the resulting TIR.
 *
 * \return The IR transform pass.
 */
TVM_DLL Pass InjectSoftwarePipeline();

/*!
 * \brief Bind the given constants to the corresponding parameters of the PrimFunc.
 * \param constants The constants to bind.
 * \return The pass.
 */
TVM_DLL Pass BindParams(const Array<runtime::NDArray>& constants);

/*!
 * \brief Pass to collect TIR non-scalar constants into the module's 'Constants' attribute.
 *
 * \return The pass.
 */
TVM_DLL Pass ExtractPrimFuncConstants();

/*!
 * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()).
 * \return The pass.
 */
TVM_DLL Pass RenormalizeSplitPattern();

/*!
 * \brief Annotate a PrimFunc with a given target.
 * \return The pass.
 */
TVM_DLL Pass BindTarget(Target target);

/*!
 * \brief Set a PrimFunc as the entry point if it is the only function in the IRModule.
 * \return The pass.
 */
TVM_DLL Pass AnnotateEntryFunc();
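
/*
 * Example usage of BindTarget and AnnotateEntryFunc (illustrative sketch): the
 * target string "llvm" and the module `mod` (assumed to contain a single
 * PrimFunc) are assumptions made only for this example.
 *
 * \code{.cpp}
 *   tvm::transform::Sequential prep({
 *       tvm::tir::transform::BindTarget(tvm::Target("llvm")),
 *       tvm::tir::transform::AnnotateEntryFunc(),
 *   });
 *   mod = prep(mod);
 * \endcode
 */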

/*!
 * \brief Filter PrimFuncs with a given condition.
 * \return The pass.
 */
TVM_DLL Pass Filter(runtime::TypedPackedFunc<bool(PrimFunc)> fcond);
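
/*
 * Example usage of Filter (illustrative sketch): keep only PrimFuncs that carry
 * a "global_symbol" attribute. The module `mod` and the particular predicate
 * are assumptions made only for this example.
 *
 * \code{.cpp}
 *   tvm::transform::Pass keep_externally_visible = tvm::tir::transform::Filter(
 *       [](tvm::tir::PrimFunc f) -> bool {
 *         return f->GetAttr<tvm::String>(tvm::attr::kGlobalSymbol).defined();
 *       });
 *   mod = keep_externally_visible(mod);
 * \endcode
 */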

/*!
 * \brief Pass to rewrite global-to-shared memory copies on CUDA with asynchronous copy.
 * \return The pass.
 */
TVM_DLL Pass InjectPTXAsyncCopy();

/*!
 * \brief Remove the weight layout rewrite block.
 * \param skip_ndarray_rewrite If true, the exact rewrite of NDArray, according to the given index
 * map, will be skipped. Only the shape of the NDArray is transformed correctly, and the content
 * of the destination array will be filled with random values.
 *
 * When this pass is called many times during MetaSchedule tuning, the raw data of NDArray,
 * before and after rewrite, does not matter. Since NDArray layout rewrite, using IndexMap's
 * MapNDArray, is currently slow, skipping the exact rewrite is sometimes necessary.
 *
 * \return The pass.
 */
TVM_DLL Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite = false);

/*!
 * \brief Add the explicit local stage for the shared memory access on GPU.
 * \return The pass.
 */
TVM_DLL Pass ManifestSharedMemoryLocalStage();

/*!
 * \brief Insert intrinsic calls to instrument function and loop level profiling.
 * \return The pass.
 */
TVM_DLL Pass InstrumentProfileIntrinsics();

}  // namespace transform
}  // namespace tir
}  // namespace tvm

#endif  // TVM_TIR_TRANSFORM_H_