1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/tir/transform.h |
22 | * \brief TIR specific transformation passes. |
23 | */ |
24 | #ifndef TVM_TIR_TRANSFORM_H_ |
25 | #define TVM_TIR_TRANSFORM_H_ |
26 | |
27 | #include <tvm/ir/transform.h> |
28 | #include <tvm/target/target.h> |
29 | #include <tvm/tir/expr.h> |
30 | #include <tvm/tir/function.h> |
31 | |
32 | #include <string> |
33 | #include <vector> |
34 | |
35 | namespace tvm { |
36 | namespace tir { |
37 | namespace transform { |
38 | |
39 | using tvm::transform::Pass; |
40 | using tvm::transform::PassContext; |
41 | using tvm::transform::PassContextNode; |
42 | using tvm::transform::PassInfo; |
43 | using tvm::transform::PassInfoNode; |
44 | using tvm::transform::PassNode; |
45 | using tvm::transform::Sequential; |
46 | |
47 | /* |
48 | * \brief Create a function pass that optimizes PrimFuncs. |
49 | * |
50 | * \param pass_func The packed function that contains the optimization. |
51 | * \param opt_level The optimization level of the function pass. |
52 | * \param name The name of the function pass. |
53 | * \param required The list of the passes that the function pass is dependent on. |
54 | * |
55 | * \return The created function pass. |
56 | */ |
57 | TVM_DLL Pass CreatePrimFuncPass( |
58 | const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func, |
59 | int opt_level, String name, tvm::Array<String> required); |
60 | |
61 | /*! |
62 | * \brief Inject prefetch instructions into stmt. |
63 | * |
64 | * \return The pass. |
65 | */ |
66 | TVM_DLL Pass InjectPrefetch(); |
67 | |
68 | // TODO(tvm-team): consolidate configs to the PassContext |
69 | /*! |
70 | * \brief Flatten the multi-dimensional read/write |
71 | * to single dimensional Load/Store |
72 | * |
73 | * \param cache_line_size The size of CPU cache line. |
74 | * \param create_bound_attribute Whether to create bound attributes. |
75 | * |
76 | * \return The Pass |
77 | */ |
78 | TVM_DLL Pass StorageFlatten(int cache_line_size, bool create_bound_attribute = false); |
79 | |
80 | /*! |
81 | * \brief Inject copy intrinsics with optional pad. |
82 | * |
83 | * \param pragma_key The pragma key for hint of copy. |
84 | * \param fintrin The function with signature |
85 | * |
86 | * Stmt fintrin(Buffer src, |
87 | * Buffer dst, |
88 | * Array<Expr> pad_before, |
89 | * Array<Expr> pad_after, |
90 | * Expr pad_value) |
91 | * \return The pass. |
92 | */ |
93 | TVM_DLL Pass InjectCopyIntrin(String pragma_key, runtime::PackedFunc fintrin); |
94 | |
95 | /*! |
96 | * \brief Detect and insert sync points to co-processor. |
97 | * |
98 | * \return The pass. |
99 | */ |
100 | TVM_DLL Pass CoProcSync(); |
101 | |
102 | /*! |
103 | * \brief Lift common attrs with attr_key to outer scope. |
104 | * |
105 | * \param attr_key The attribute key to be checked. |
106 | * \return The pass. |
107 | */ |
108 | TVM_DLL Pass LiftAttrScope(String attr_key); |
109 | |
110 | /*! |
111 | * \brief partition loops in the stmt. |
112 | * |
113 | * \return The pass. |
114 | */ |
115 | TVM_DLL Pass LoopPartition(); |
116 | |
117 | /*! |
118 | * \brief Lower vectorization loops. |
119 | * |
120 | * \param enable_vectorize Whether vectorization is enabled. |
121 | * |
122 | * \return The pass. |
123 | */ |
124 | TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true); |
125 | |
126 | /*! |
127 | * \brief Inject virtual thread loops. |
128 | * |
129 | * \return The pass. |
130 | */ |
131 | TVM_DLL Pass InjectVirtualThread(); |
132 | |
133 | /*! |
134 | * \brief Inject double buffer statements. |
135 | * |
136 | * \return The pass. |
137 | */ |
138 | TVM_DLL Pass InjectDoubleBuffer(); |
139 | |
140 | /*! |
141 | * \brief Rewrite storage allocation pattern. |
142 | * Moves the allocation to outer most possible scope. |
143 | * Trying to share space between allocations to make |
144 | * a static allocation plan when possible. |
145 | * |
146 | * \return The pass. |
147 | */ |
148 | TVM_DLL Pass StorageRewrite(); |
149 | |
150 | /*! |
151 | * \brief unroll the constant loop marked by unroll. |
152 | * This pass also automatically attach pragma unroll tag to loops which meets the standard. |
153 | * |
154 | * \return The pass. |
155 | */ |
156 | TVM_DLL Pass UnrollLoop(); |
157 | |
158 | /*! |
159 | * \brief Remove No Op from the Stmt. |
160 | * |
161 | * \return The pass. |
162 | */ |
163 | TVM_DLL Pass RemoveNoOp(); |
164 | |
165 | /*! |
166 | * \brief Detect and rewrite unsafe select that contains memory access. |
167 | * |
168 | * \return The pass. |
169 | */ |
170 | TVM_DLL Pass RewriteUnsafeSelect(); |
171 | |
172 | /*! |
173 | * \brief Run arithmetic simplifications on the statements and expressions. |
174 | * |
175 | * \return The pass. |
176 | */ |
177 | TVM_DLL Pass Simplify(); |
178 | |
179 | /*! |
180 | * \brief Instruments bound checkers. |
181 | * |
182 | * \return The pass. |
183 | */ |
184 | TVM_DLL Pass InstrumentBoundCheckers(); |
185 | |
186 | /*! |
187 | * \brief Transform the high-level PrimFunc to a low-level version |
188 | * that can be used as an API function. |
189 | * |
190 | * |
191 | * The main task of this function is to create code to : |
192 | * - Map the values in the api_args to Var that is required by body. |
193 | * - Insert assertions to check type/value of the passed arguments. |
194 | * |
195 | * \note |
196 | * The function signature have two cases |
197 | * |
198 | * let num_packed_args = len(api_args); |
199 | * |
200 | * if num_packed_args is zero: |
201 | * f() |
202 | * |
203 | * if num_packed_args is not zero: |
204 | * f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args, |
205 | * api_arg_k, api_arg_k+1, ... api_arg_n, |
206 | * TVMValue* out_ret_val, int* out_ret_tcode) |
207 | * |
208 | * where n == len(api_args), k == num_packed_args |
209 | * |
210 | * \return The pass. |
211 | */ |
212 | TVM_DLL Pass MakePackedAPI(); |
213 | |
214 | /*! |
215 | * \brief Transform the high-level PrimFunc to a C signature that can be used |
216 | * to call the operator directly. |
217 | * |
218 | * The main task of this function is to create code that maps the values in the |
219 | * api_args to Var that is required by body |
220 | * |
221 | * \return The pass. |
222 | */ |
223 | TVM_DLL Pass MakeUnpackedAPI(); |
224 | |
225 | /*! |
226 | * \brief Remap the thread axis |
227 | * |
228 | * This can be used to get equivalent program which uses |
229 | * threadIdx.y in place of threadIdx.x by passing |
230 | * {"threadIdx.x": thread_axis("threadIdx.y")} |
231 | * |
232 | * |
233 | * \return The pass. |
234 | */ |
235 | TVM_DLL Pass RemapThreadAxis(Map<String, IterVar> axis_map); |
236 | |
237 | /*! |
238 | * \brief Lower custom datatypes. |
239 | * |
240 | * See tvm::datatypes::Registry for more information on adding custom datatypes. |
241 | * |
242 | * \return The pass. |
243 | */ |
244 | TVM_DLL Pass LowerCustomDatatypes(); |
245 | |
246 | /*! |
247 | * \brief Decorate all the function's body as device function. |
248 | * |
249 | * \return The pass. |
250 | */ |
251 | TVM_DLL Pass DecorateDeviceScope(); |
252 | |
253 | /*! |
254 | * \brief Split the function into a host function and device functions. |
255 | * |
256 | * \return The pass. |
257 | */ |
258 | TVM_DLL Pass SplitHostDevice(); |
259 | |
260 | /*! |
261 | * \brief skip assert stmt. |
262 | * |
263 | * \return The pass. |
264 | */ |
265 | TVM_DLL Pass SkipAssert(); |
266 | |
267 | /*! |
268 | * \brief Insert sync between parallel read/write of shared buffers. |
269 | * |
270 | * \param storage_scope The storage scope considered. |
271 | * \return The pass. |
272 | */ |
273 | TVM_DLL Pass ThreadSync(String storage_scope); |
274 | |
275 | /*! |
276 | * \brief Lower cross thread alleduce. |
277 | * |
278 | * \return The pass. |
279 | */ |
280 | TVM_DLL Pass LowerThreadAllreduce(); |
281 | |
282 | /*! |
283 | * \brief Infer the TensorCore fragment infomation using tensor intrinsics |
284 | * |
285 | * \return The pass. |
286 | */ |
287 | TVM_DLL Pass InferFragment(); |
288 | |
289 | /*! |
290 | * \brief This annotation is for nodes to be disabled for builtin lowering |
291 | */ |
292 | static constexpr const char* kDisableLowerTVMBuiltin = "disable_lower_builtin" ; |
293 | |
294 | /*! |
295 | * \brief Lower builtin intrinsics. |
296 | * \return The pass. |
297 | */ |
298 | TVM_DLL Pass LowerTVMBuiltin(); |
299 | |
300 | /*! |
301 | * \brief Lower the target specific function intrinsics in each of the function. |
302 | * |
303 | * \return The pass. |
304 | */ |
305 | TVM_DLL Pass LowerIntrin(); |
306 | |
307 | /*! |
308 | * \brief Lower warp memory access to low-level device related function calls. |
309 | * \return The pass. |
310 | */ |
311 | TVM_DLL Pass LowerWarpMemory(); |
312 | |
313 | /*! |
314 | * \brief Lower attached storage access information on device. |
315 | * |
316 | * \note Run this pass after all storage access analysis finish. |
317 | * |
318 | * \return The pass. |
319 | */ |
320 | TVM_DLL Pass LowerDeviceStorageAccessInfo(); |
321 | |
322 | /*! |
323 | * \brief Combine context calls in the host function. |
324 | * |
325 | * \return The pass. |
326 | */ |
327 | TVM_DLL Pass CombineContextCall(); |
328 | |
329 | /*! |
330 | * \brief Narrow down PrimExpr datatype in stmt to target_bits. |
331 | * |
332 | * \param target_bits The target bits |
333 | * |
334 | * \note Run this pass after storage flatten. |
335 | * \return The pass. |
336 | */ |
337 | TVM_DLL Pass NarrowDataType(int target_bits); |
338 | |
339 | /*! |
340 | * \brief Legalize bf16 typed Ops. Add a cast to fp32 |
341 | * before Ops, then add a cast back to bf16. |
342 | * \return The pass. |
343 | */ |
344 | TVM_DLL Pass BF16Legalize(); |
345 | |
346 | /*! |
347 | * \brief Rewrite the pointer content type of arguments, |
348 | * as well as Alloc internal to the function to use |
349 | * the most frequently accessed type for load/store |
350 | * to avoid pointer casting in backend when possible. |
351 | * |
352 | * \return The pass. |
353 | */ |
354 | TVM_DLL Pass PointerValueTypeRewrite(); |
355 | |
356 | /*! |
357 | * \brief Hoist loop-invariant IfThenElse nodes to |
358 | * outside the elligible loops. |
359 | * |
360 | * \return The pass. |
361 | */ |
362 | TVM_DLL Pass HoistIfThenElse(); |
363 | |
364 | /*! |
365 | * \brief Hoist loop-invariant expressions nodes to |
366 | * outside the elligible loops. |
367 | * |
368 | * Can hoist conditionals used in IfThenElse statements and |
369 | * expressions, bindings of variables in Let statements and |
370 | * expressions, or boolean expressions, configurable to enable/disable |
371 | * each hoistable type. |
372 | * |
373 | * \return The pass. |
374 | */ |
375 | TVM_DLL Pass HoistExpression(); |
376 | |
377 | /*! |
378 | * \brief Lower cross-thread reduction from thread |
379 | * bindings to intrinsic function calls. |
380 | * \return The pass. |
381 | */ |
382 | TVM_DLL Pass LowerCrossThreadReduction(); |
383 | |
384 | /*! |
385 | * \brief Lower block init stmt into IfThenElse stmts |
386 | * \return The pass. |
387 | */ |
388 | TVM_DLL Pass LowerInitBlock(); |
389 | |
390 | /*! |
391 | * \brief Locate the buffer allocation to the exact position (usually is |
392 | * the lca of buffer access). This pass will inject opaque block |
393 | * with alloc_buffers at the allocation site. |
394 | * \return The pass. |
395 | */ |
396 | TVM_DLL Pass PlanAndUpdateBufferAllocationLocation(); |
397 | |
398 | /*! |
399 | * \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the |
400 | * corresponding iter_values in BlockRealize, for opaque blocks by removing all |
401 | *. the iter_values in BlockRealize and iter_vars in Block. |
402 | * \return The pass. |
403 | */ |
404 | TVM_DLL Pass ConvertBlocksToOpaque(); |
405 | |
406 | /*! |
407 | * \brief Compact the buffer access region by removing the buffer regions that are not accessed, |
408 | * i.e. narrowing the buffer shape and adjust the access region if necessary. |
409 | * |
410 | * Before narrowing, `B` is a `[16, 16]` buffer, but only a skinny vector `B[i, 0:16]` is accessed. |
411 | * |
412 | * \code |
413 | * |
414 | * for i in range(0, 16): |
415 | * with T.block(): |
416 | * B = T.alloc_buffer(16, 16) |
417 | * for j in range(0, 16): |
418 | * B[i, j] = A[i, j] + 1 |
419 | * for j in range(0, 16): |
420 | * C[i, j] = B[i, j] + 1 |
421 | * |
422 | * \endcode |
423 | * |
424 | * This pass narrows the buffer shape and adjust its accessed region accordingly. |
425 | * In this particular case, because only a `1 * 16` vector of `B` is accessed, |
426 | * the pass narrows `B` to shape `[1, 16]`, and changes the access to `B[i, j]` to `B[0, j]`. |
427 | * |
428 | * \code |
429 | * |
430 | * for i in range(0, 16): |
431 | * with T.block(): |
432 | * B = T.alloc_buffer(1, 16) |
433 | * for j in range(0, 16): |
434 | * B[0, j] = A[i, j] + 1 |
435 | * for j in range(0, 16): |
436 | * C[i, j] = B[0, j] + 1 |
437 | * |
438 | * \endcode |
439 | * |
440 | * |
441 | * \return The pass. |
442 | */ |
443 | TVM_DLL Pass CompactBufferAllocation(); |
444 | |
445 | /*! |
446 | * This pass legalizes packed calls by wrapping their arguments into TVMValues |
447 | */ |
448 | TVM_DLL Pass LegalizePackedCalls(); |
449 | |
450 | /*! |
451 | * \brief Remove match buffers inside the block. Also, it will validate the binding. |
452 | * \return The pass. |
453 | */ |
454 | TVM_DLL Pass LowerMatchBuffer(); |
455 | |
456 | /*! |
457 | * \brief Remove the block to ensure that the TIR can not be scheduled again. |
458 | * \return The pass. |
459 | */ |
460 | TVM_DLL Pass LowerOpaqueBlock(); |
461 | |
462 | /*! |
463 | * \brief Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional |
464 | * BufferLoad/BufferStore for the TIR not contains opaque block. |
465 | * \return The pass. |
466 | */ |
467 | TVM_DLL Pass FlattenBuffer(); |
468 | |
469 | /* |
470 | * \brief Flatten the multi-dimensional read/write |
471 | * to two dimensional texture Load/Store and realize |
472 | * texture buffer allocations. |
473 | * |
474 | * \return The Pass |
475 | */ |
476 | TVM_DLL Pass TextureFlatten(); |
477 | |
478 | /* |
479 | * \brief Lower VTCM allocations |
480 | * |
481 | * \return The Pass |
482 | */ |
483 | TVM_DLL Pass LowerVtcmAlloc(); |
484 | |
485 | /*! |
486 | * \brief Lower Async TIR primitives to DMA copy and wait builtins |
487 | */ |
488 | TVM_DLL Pass LowerAsyncDMA(); |
489 | |
490 | /*! |
491 | * \brief Implements a Common Subexpression Elimination (CSE) for TIR |
492 | * which introduces let-in bindings for duplicated sub-expressions. |
493 | * \param enable_cse_tir Whether common subexpression elimination is enabled. |
494 | * \param identify_equiv_terms Whether equivalent terms should be identified. |
495 | * \return The pass. |
496 | */ |
497 | TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false); |
498 | |
499 | /*! |
500 | * \brief Add TIR-printer output as debug information to all ops in the module |
501 | * \return The pass. |
502 | */ |
503 | |
504 | TVM_DLL Pass InstallDebugSpans(); |
505 | |
506 | /*! |
507 | * \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and |
508 | * "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g., |
509 | * "threadIdx.x") use different IterVars and variables in their AttrStmts. After the |
510 | * unification, we use a consolidated IterVar and a variable for them. |
511 | * \return The pass. |
512 | * \note `vthread` is a legacy behavior that will be deprecated, though thread bindings of `vthread` |
513 | * are still also unified in this pass. Please use `vthread.x`, `vthread.y` and `vthread.z` |
514 | * instead. |
515 | */ |
516 | TVM_DLL Pass UnifyThreadBinding(); |
517 | |
518 | /*! |
519 | * A pass to merge multiple TIR-level dynamic shared memory allocations into one |
520 | */ |
521 | TVM_DLL Pass MergeDynamicSharedMemoryAllocations(); |
522 | |
523 | /*! |
524 | * \brief This pass is post-scheduling pass to convert all |
525 | * Parallel For loops to Serial ones. This is run |
526 | * to attain lesser memory and/or executor/backend |
527 | * does not support parallel launch of For loops. |
528 | * \return The pass. |
529 | */ |
530 | TVM_DLL Pass ConvertForLoopsToSerial(); |
531 | |
532 | /*! |
533 | * \brief This is the unified static memory planner pass that will |
534 | * plan for memory intra- and inter- PrimFuncs together. The pass |
535 | * requires all the function to be PrimFuncs including the main. |
536 | * \return The pass. |
537 | */ |
538 | TVM_DLL Pass UnifiedStaticMemoryPlanner(); |
539 | |
540 | /*! |
541 | * \brief This pass transforms annotated loops into pipelined ones where producers and consumers |
542 | * are overlapped with the information provided in loop annotations, which enables optimization |
543 | * techniques like prefetching and pipeline parallelism. |
544 | * |
545 | * The pipeline scope consists of the direct children of the annotated loop (ignoring BlockRealize, |
546 | * Block, SeqStmt), and the number of children is denoted by `n` in the documentation. |
547 | * |
548 | * The following annotations are used to guide the loop transformation: |
549 | * |
550 | * 1) Loop annotation `software_pipeline_stage` defines the pipeline stage. |
551 | * An array of `n` integers, and each element should be in range [0, max_stage], |
552 | * where max_stage is the maximum (inclusive) stage. |
553 | * 2) Loop annotation `software_pipeline_order` defines the pipeline order. |
554 | * An array of `n` integers, a permutation of [0, 1, ..., num_components - 1]; |
555 | * 3) Block annotation `double_buffer_scope` controls certain buffer sizes to allow decoupling of |
556 | * read/write dependency. It's an integer index of the write regions of the block. |
557 | * |
558 | * Every annotated loop is transformed into a loop with three blocks as its direct children: |
559 | * |
560 | * 1) Prologue block, where components whose stage is less than `max_stage` is executed; |
561 | * |
562 | * 2) Body block, where all the components are executed; |
563 | * |
564 | * 3) Epilogue block, where only components whose stage is greater than 0 will be executed. |
565 | * The execution order is controlled by the annotation `software_pipeline_order`, |
566 | * and thus could be different than the original order. |
567 | * |
568 | * Note: For nested software pipelines, the inner software pipeline will be generated first, |
569 | * which may affect the number of the direct children of the outer loop. |
570 | * In this case, the annotations for the outer software |
571 | * pipeline should include the result of the inner software pipeline, |
572 | * which is the three blocks as discussed above. |
573 | * Example: |
574 | * |
575 | * Before this pass, the TIR is: |
576 | * |
577 | * \code{.py} |
578 | * @T.prim_func |
579 | * def before_transform(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None: |
580 | * for tx in T.thread_binding(0, 16, thread="threadIdx.x"): |
581 | * for i in T.serial(0, 16, |
582 | * annotations={"software_pipeline_stage": [0, 1], |
583 | * "software_pipeline_order": [0, 1]} |
584 | * ): |
585 | * with T.block(): |
586 | * T.reads(A[tx, i]) |
587 | * T.writes(C[tx, i]) |
588 | * B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") |
589 | * with T.block("B"): |
590 | * T.reads(A[tx, i]) |
591 | * T.writes(B[tx, 0]) |
592 | * B[tx, 0] = A[tx, i] * T.float32(2) |
593 | * with T.block("C"): |
594 | * T.reads(B[tx, 0]) |
595 | * T.writes(C[tx, i]) |
596 | * C[tx, i] = B[tx, 0] + T.float32(1) |
597 | * \endcode |
598 | * |
599 | * The TIR above annotates the loop as a two-stage pipeline with no reordering. |
600 | * After applying this pass, the TIR is transformed into: |
601 | * |
602 | * \code{.py} |
603 | * @T.prim_func |
604 | * def after_transform(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")) -> None: |
605 | * for tx in T.thread_binding(0, 16, thread="threadIdx.x"): |
606 | * with T.block(): |
607 | * T.reads([A[tx, 0:16]]) |
608 | * T.writes([C[tx, 0:16]]) |
609 | * B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") |
610 | * with T.block("prologue"): |
611 | * T.reads([A[tx, 0]]) |
612 | * T.writes([B[0, tx, 0]]) |
613 | * B[0, tx, 0] = A[tx, 0] * T.float32(2) |
614 | * with T.block("body"): |
615 | * T.reads([A[tx, 1:16], B[0:2, tx, 0]]) |
616 | * T.writes([B[0:2, tx, 0], C[tx, 0:15]]) |
617 | * for i in T.serial(0, 15): |
618 | * with T.block("B"): |
619 | * T.reads([A[tx, i + 1]]) |
620 | * T.writes([B[(i + 1) % 2, tx, 0]]) |
621 | * B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) |
622 | * with T.block("C"): |
623 | * T.reads([B[i % 2, tx, 0]]) |
624 | * T.writes([C[tx, i]]) |
625 | * C[tx, i] = B[i % 2, tx, 0] + T.float32(1) |
626 | * with T.block("epilogue"): |
627 | * T.reads([B[1, tx, 0]]) |
628 | * T.writes([C[tx, 15]]) |
629 | * C[tx, 15] = B[1, tx, 0] + T.float32(1) |
630 | * \endcode |
631 | * |
632 | * The original loop has two blocks, B and C, as its direct children. The loop annotations indicate |
633 | * that block B has stage == 0, order == 0, block C has stage == 1, order == 1. Therefore, block B |
634 | * should be executed in advance of block C by one iteration. The order 0 and 1 specifies the order |
635 | * of block B and C inside the body block inside the result TIR. |
636 | * |
637 | * \return The IR transform pass. |
638 | */ |
639 | TVM_DLL Pass InjectSoftwarePipeline(); |
640 | |
641 | TVM_DLL Pass BindParams(const Array<runtime::NDArray>& constants); |
642 | |
643 | /*! |
644 | * \brief Pass to collect tir non-scalar constants into module's 'Constants' attribute. |
645 | * |
646 | * \return The pass. |
647 | */ |
648 | TVM_DLL Pass (); |
649 | |
650 | /*! |
651 | * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) |
652 | * \return The pass. |
653 | */ |
654 | TVM_DLL Pass RenormalizeSplitPattern(); |
655 | |
656 | /*! |
657 | * \brief Annotate a PrimFunc with a given target. |
658 | * \return The pass. |
659 | */ |
660 | TVM_DLL Pass BindTarget(Target target); |
661 | |
662 | /*! |
663 | * \brief Set a PrimFunc as the entry point if it is only function in IRModule. |
664 | * \return The pass. |
665 | */ |
666 | TVM_DLL Pass AnnotateEntryFunc(); |
667 | |
668 | /*! |
669 | * \brief Filter PrimFuncs with a given condition. |
670 | * \return The pass. |
671 | */ |
672 | TVM_DLL Pass Filter(runtime::TypedPackedFunc<bool(PrimFunc)> fcond); |
673 | |
674 | /*! |
675 | * \brief Pass to rewrite global to shared memory copy on CUDA with asyncronous copy. |
676 | * \return The pass. |
677 | */ |
678 | TVM_DLL Pass InjectPTXAsyncCopy(); |
679 | |
680 | /*! |
681 | * \brief Remove the weight layout rewrite block |
682 | * \param skip_ndarray_rewrite If True, exact rewrite of NDArray, according to the given index map, |
683 | * will be skipped. Only the shape of the NDArray is transformed correctly, and the content of |
684 | * the destination array will be filled with random values. |
685 | * |
686 | * When this pass is called many times during MetaSchedule tuning, the raw data of NDArray, |
687 | * before and after rewrite, does not matter. Since NDArray layout rewrite, using IndexMap's |
688 | * MapNDArray, is currently slow, skipping the exact rewrite is sometimes necessary. |
689 | * |
690 | * \return The pass. |
691 | */ |
692 | TVM_DLL Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite = false); |
693 | |
694 | /*! |
695 | * \brief Add the explicit local stage for the shared memory access on GPU. |
696 | * \return The pass. |
697 | */ |
698 | TVM_DLL Pass ManifestSharedMemoryLocalStage(); |
699 | |
700 | /*! |
701 | * \brief Insert intrinsic calls to instrument function and loop level profiling. |
702 | * \return The pass. |
703 | */ |
704 | TVM_DLL Pass InstrumentProfileIntrinsics(); |
705 | |
706 | } // namespace transform |
707 | } // namespace tir |
708 | } // namespace tvm |
709 | |
710 | #endif // TVM_TIR_TRANSFORM_H_ |
711 | |