1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/relay/transform.h |
22 | * \brief Relay specific transformation passes. |
23 | */ |
24 | #ifndef TVM_RELAY_TRANSFORM_H_ |
25 | #define TVM_RELAY_TRANSFORM_H_ |
26 | |
27 | #include <tvm/ir/transform.h> |
28 | #include <tvm/relay/attrs/transform.h> |
29 | #include <tvm/relay/expr.h> |
30 | #include <tvm/relay/function.h> |
31 | #include <tvm/relay/op.h> |
32 | #include <tvm/relay/op_attr_types.h> |
33 | #include <tvm/target/compilation_config.h> |
34 | #include <tvm/target/target.h> |
35 | #include <tvm/target/virtual_device.h> |
36 | |
37 | #include <string> |
38 | |
39 | namespace tvm { |
40 | namespace relay { |
41 | namespace transform { |
42 | |
// Bring the generic pass-infrastructure types from tvm::transform into this namespace so that
// the declarations below can name them unqualified.
43 | using Pass = tvm::transform::Pass; |
44 | using PassNode = tvm::transform::PassNode; |
45 | using PassInfo = tvm::transform::PassInfo; |
46 | using PassInfoNode = tvm::transform::PassInfoNode; |
47 | using PassContext = tvm::transform::PassContext; |
48 | using PassContextNode = tvm::transform::PassContextNode; |
49 | using Sequential = tvm::transform::Sequential; |
50 | |
51 | /*! |
52 | * \brief Create a function pass. |
53 | * |
54 | * \param pass_func The packed function that contains the optimization. |
55 | * \param opt_level The optimization level of the function pass. |
56 | * \param name The name of the function pass. |
57 | * \param required The list of the passes that the function pass is dependent on. |
58 | * |
59 | * \return The created function pass. |
60 | */ |
61 | TVM_DLL Pass CreateFunctionPass( |
62 | const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func, |
63 | int opt_level, String name, tvm::Array<String> required); |
64 | |
65 | /*! \brief Remove let-bound expressions which do not affect the program result. |
66 | * |
67 | * This pass will remove let bindings which are not referenced. If inline_once is True, |
68 | * let bindings which are only referenced once will also be inlined. |
69 | * |
70 | * For example, this pass should turn `let a = 1; 2` into `2`, |
71 | * as the value of the expression does not depend on a. |
72 | * |
73 | * As another example, `let a = 1; a` will be optimized into 1 if inline_once is True. |
74 | * |
75 | * If ignore_purity is False, possibly side-effecting expressions (such as memory allocation, |
76 | * random number generation, reading/writing references, or calls to primitive or external |
77 | * functions) are never elided or inlined. This is sound, but ignore_purity can be set to True |
78 | * to suppress this check. |
79 | * |
80 | * The analysis is fairly conservative, for example it assumes all local functions |
81 | * may be called more than once, any functions passed as arguments have side effects, |
82 | * and so on. |
83 | * |
84 | * \param inline_once Whether or not to inline bindings used exactly once (defaults to false). |
85 | * \param ignore_purity Whether to ignore whether expressions have side-effects (defaults to false). |
86 | * |
87 | * \return The pass. |
88 | */ |
89 | TVM_DLL Pass DeadCodeElimination(bool inline_once = false, bool ignore_purity = false); |
90 | |
91 | /*! |
92 | * \brief Convert all expressions of TensorType into GradCell, |
93 | * an algebraic data type defined in gradient.rly. |
94 | * |
95 | * This will delay or decrease memory usage. All calls to |
96 | * ones, ones_like, zeros, zeros_like will not immediately instantiate a tensor in memory, |
97 | * but will only instantiate one if needed. It also defines the + and * operations |
98 | * between GradCell types, which can increase performance when using |
99 | * zero-filled or one-filled tensors, which is the case in reverse-mode AD. |
100 | * |
101 | * \return The pass. |
102 | */ |
103 | TVM_DLL Pass LazyGradientInit(); |
104 | |
105 | /*! |
106 | * \brief Fold constant expressions. |
107 | * |
108 | * For backward-compatibility reasons, QNN primitives are skipped from folding by default. |
109 | * Some transformation passes, such as FakeQuantizationToInteger, require QNN primitives to be |
110 | * kept in constant subgraphs. Uncontrolled constant folding of QNN primitives may break the |
111 | * applicability of FakeQuantizationToInteger. We suggest using the FoldConstant pass with the |
112 | * non-default fold_qnn=True value only after all other QNN-sensitive passes have been applied. |
113 | * |
114 | * \param fold_qnn Whether to fold constants for QNN operations. |
115 | * |
116 | * \return The pass. |
117 | */ |
118 | TVM_DLL Pass FoldConstant(bool fold_qnn = false); |
119 | |
120 | /*! |
121 | * \brief Split function with huge number of arguments to smaller pieces. |
122 | * |
* \param max_function_args NOTE(review): presumably the maximum number of arguments a function
*        may have before it is split; the exact semantics are not visible in this header -
*        TODO confirm against the pass implementation.
*
123 | * \return The pass. |
124 | */ |
125 | TVM_DLL Pass SplitArgs(int max_function_args); |
126 | |
127 | /*! |
128 | * \brief Fuse operations in an expr into separate functions. |
129 | * |
130 | * \param fuse_opt_level Optimization level. If it is -1 it will be inferred from pass context. |
131 | * |
132 | * \return The pass. |
133 | */ |
134 | TVM_DLL Pass FuseOps(int fuse_opt_level = -1); |
135 | |
136 | /*! |
137 | * \brief The inverse operation of FuseOps. It transforms a fused program returned by |
138 | * FuseOps back into the program before FuseOps (i.e. x == DefuseOps(FuseOps(x))). |
139 | * |
140 | * \return The pass. |
141 | */ |
142 | TVM_DLL Pass DefuseOps(); |
143 | |
144 | /*! |
145 | * \brief Rewrite the annotated program. |
146 | * |
147 | * \param fallback_device The fallback device, i.e. the default device for |
148 | * operators without annotation. |
149 | * |
150 | * \return The pass. |
151 | */ |
152 | TVM_DLL Pass RewriteAnnotatedOps(int fallback_device); |
153 | |
154 | /*! |
155 | * \brief Turn an expression to Basic Block Normal Form. |
156 | * |
157 | * We define a block as a group of expressions implied by the scope structure. |
158 | * |
159 | * Each graph node can only belong to a single block. |
160 | * |
161 | * Any value that is used in multiple blocks has to be referred to |
162 | * by a Var which is defined in a block whose scope is the least common ancestor |
163 | * of the blocks in which this value is used. |
164 | * |
165 | * \return The pass. |
166 | */ |
167 | TVM_DLL Pass ToBasicBlockNormalForm(); |
168 | |
169 | /*! |
170 | * \brief Turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF). |
171 | * |
172 | * It will turn an expression that is in a graph form (with sharing implicit), |
173 | * to an expression with explicit sharing (A-Normal Form). |
174 | * |
175 | * The scope of the root expression is the global scope. |
176 | * |
177 | * The scope of any non-root expression is the least common ancestor of all its scopes. |
178 | * |
179 | * Values are ordered by post-DFS order in each scope. |
180 | * |
181 | * \return The pass. |
182 | */ |
183 | TVM_DLL Pass ToANormalForm(); |
184 | |
185 | /*! |
186 | * \brief ToANormalForm, but on an incomplete graph. |
187 | * |
188 | * \param expr The graph. |
189 | * |
190 | * \return The transformed program. |
191 | */ |
192 | TVM_DLL Expr ToANormalForm(const Expr& expr); |
193 | |
194 | /*! |
195 | * \brief Turn an expression into continuation passing style (CPS). |
196 | * |
197 | * CPS means that every function will, instead of returning the result directly, |
198 | * be passed down an extra function (called the continuation) as argument, |
199 | * and pass the result to the continuation instead. |
200 | * |
201 | * Thus, every function call has to be passed an extra argument |
202 | * that represents the rest of the computation (hence the name of continuation). |
203 | * |
204 | * Similarly, all other compute will be wrapped and call the continuation as well. |
205 | * |
206 | * \return The pass. |
207 | */ |
208 | TVM_DLL Pass ToCPS(); |
209 | |
210 | /*! |
211 | * \brief Remove let bindings and directly share via pointer instead. |
212 | * |
213 | * It will remove all let bindings, |
214 | * and turn all of the variables bound by let into direct pointer references. |
215 | * |
216 | * \return The pass. |
217 | */ |
218 | TVM_DLL Pass ToGraphNormalForm(); |
219 | |
220 | /*! |
221 | * \brief Aggressive constant propagation/constant folding/inlining. |
222 | * |
223 | * It will do as much computation at compile time as possible. |
224 | * It has two benefits: removing runtime overhead, and allowing more optimization (typically fusion). |
225 | * As a side effect, code size will explode. |
226 | * |
227 | * \return The pass. |
228 | */ |
229 | TVM_DLL Pass PartialEval(); |
230 | |
231 | /*! |
232 | * \brief Simplify certain operators during inference. For example, the result |
233 | * of a batch norm which is indexed at tuple index 0 will be unpacked into a |
234 | * number of simplified operators. |
235 | * |
236 | * \return The pass. |
237 | */ |
238 | TVM_DLL Pass SimplifyInference(); |
239 | |
240 | /*! |
241 | * \brief Replaces non-linear activation functions with their fast but approximate counterparts. |
242 | * |
243 | * \return The pass. |
244 | */ |
245 | TVM_DLL Pass FastMath(); |
246 | |
247 | /*! |
248 | * \brief Find dynamic ops and make them static. |
249 | * |
250 | * Searches the graph for dynamic ops. If the dynamic inputs to those ops are constants, it replaces |
251 | * them with static ops and re-performs type inference and constant folding. The pass repeats |
252 | * itself until the graph stops changing or we run too many iterations. |
253 | * |
254 | * \return The pass. |
255 | */ |
256 | TVM_DLL Pass DynamicToStatic(); |
257 | |
258 | /*! |
259 | * \brief Infer the type of an expression. |
260 | * |
261 | * The result of type checking is a new expression with unambiguous |
262 | * type information filled in, as well as its checked type field |
263 | * populated with the result type. |
264 | * |
265 | * \return The pass. |
266 | */ |
267 | TVM_DLL Pass InferType(); |
268 | |
269 | /*! |
270 | * \brief Infer the type of an expression, reusing existing type information. |
271 | * |
272 | * The result of type checking is a new expression with unambiguous |
273 | * type information filled in for the given node only. The local |
274 | * version can use existing type information populated throughout |
275 | * the expression and assumes this information is correct. The local |
276 | * version also avoids examining large amounts of the graph assuming |
277 | * type information is filled in properly which makes it much faster if we |
278 | * iteratively call type inference. |
279 | * |
* \param expr The expression whose type is to be inferred.
*
280 | * \return The type of the expression. |
281 | */ |
282 | TVM_DLL Type InferTypeLocal(const Expr& expr); |
283 | |
284 | /*! |
285 | * \brief Search and eliminate common subexpressions. For example, if there are |
286 | * two expressions evaluated to an identical value, a single variable is created |
287 | * and these two expressions are replaced by this variable. |
288 | * |
289 | * \param fskip The callback argument that allows skipping certain expressions (defaults to nullptr). |
290 | * |
291 | * \return The pass. |
292 | */ |
293 | TVM_DLL Pass EliminateCommonSubexpr(runtime::PackedFunc fskip = nullptr); |
294 | |
295 | /*! |
296 | * \brief Combine parallel 2d convolutions into a single convolution if the |
297 | * number of branches of this conv2d operator is not less than |
298 | * `min_num_branches`. |
299 | * |
300 | * \param min_num_branches The minimum number of branches. |
301 | * |
302 | * \return The pass. |
303 | */ |
304 | TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3); |
305 | |
306 | /*! |
307 | * \brief Combine parallel dense ops into a single batch_matmul if the |
308 | * number of branches of this dense operator is not less than |
309 | * `min_num_branches`. |
310 | * |
311 | * \param min_num_branches The minimum number of branches. |
312 | * \param to_batch_matmul Whether to combine parallel dense ops to batch matmul. |
313 | * If set false, combine dense ops to single dense op. |
314 | * |
315 | * \return The pass. |
316 | */ |
317 | TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3, bool to_batch_matmul = true); |
318 | |
319 | /*! |
320 | * \brief Combine parallel batch_matmul ops into a single batch_matmul |
321 | * if the number of branches of this batch_matmul operator is not less than |
322 | * `min_num_branches`. |
323 | * |
324 | * \param min_num_branches The minimum number of branches. |
325 | * |
326 | * \return The pass. |
327 | */ |
328 | TVM_DLL Pass CombineParallelBatchMatmul(uint64_t min_num_branches = 3); |
329 | |
330 | /*! |
331 | * \brief Fold axis scaling backward into the weights of conv/dense operators. |
332 | * |
333 | * \return The pass. |
334 | */ |
335 | TVM_DLL Pass BackwardFoldScaleAxis(); |
336 | |
337 | /*! |
338 | * \brief Fold axis scaling forward into the weights of conv/dense operators. |
339 | * |
340 | * \return The pass. |
341 | */ |
342 | TVM_DLL Pass ForwardFoldScaleAxis(); |
343 | |
344 | /*! |
345 | * \brief A sequential pass that executes the ForwardFoldScaleAxis and |
346 | * BackwardFoldScaleAxis passes. |
347 | * |
348 | * \return The pass. |
349 | */ |
350 | TVM_DLL Pass FoldScaleAxis(); |
351 | |
352 | /*! |
353 | * \brief Canonicalize some operators into simplified operators. For example, |
354 | * bias_add can be canonicalized to expand_dims and broadcast_add. |
355 | * |
356 | * \return The pass. |
357 | */ |
358 | TVM_DLL Pass CanonicalizeOps(); |
359 | |
360 | /*! |
361 | * \brief Alter the layouts of operators or replace primitive operators |
362 | * with other expressions. |
363 | * |
364 | * \return The pass. |
365 | */ |
366 | TVM_DLL Pass AlterOpLayout(); |
367 | |
368 | /*! |
369 | * \brief Do layout rewrite according to the tile structure created by auto-scheduler. |
370 | * \return The pass. |
371 | */ |
372 | TVM_DLL Pass AutoSchedulerLayoutRewrite(); |
373 | |
374 | /*! |
375 | * \brief Do layout rewrite according to the tile structure created by meta-schedule. |
376 | * \return The pass. |
377 | */ |
378 | TVM_DLL Pass MetaScheduleLayoutRewrite(); |
379 | |
380 | /*! |
381 | * \brief Given a dest layout, this pass transforms the expr such that most of the ops input data |
382 | * layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one |
383 | * at the start and one at the end. |
384 | * |
385 | * This pass is not a part of relay.build and is expected to be called between framework-relay |
386 | * parser and relay.build call. This is very helpful for hardware backends that support/prefer only |
387 | * one type of data layout. |
388 | * |
389 | * RFC - https://discuss.tvm.ai/t/layout-conversion-pass/4009 |
390 | * |
391 | * This pass uses most of the AlterOpLayout and InferCorrectLayout infrastructure. We can define new |
392 | * layouts for conv2d ops for now. Most of the other operators try to adapt to their input layout |
393 | * using the InferCorrectLayout infrastructure. |
394 | * |
395 | * \param desired_layouts Specify mapping of op_name to array of desired layouts for each input. |
396 | * For example: Map("nn.conv2d", Array("NHWC", "OHWI")), |
397 | * this specifies the desired layout for data then kernel for nn.conv2d. |
398 | * \return The pass. |
399 | */ |
400 | TVM_DLL Pass ConvertLayout(const Map<String, Array<String>>& desired_layouts); |
401 | |
402 | /*! |
403 | * \brief Legalizes an expr with another expression. |
404 | * \param legalize_map_attr_name The Op's attr name which corresponds to the legalize rule function. |
405 | * One can collect and isolate similar types of legalize transformations using this param. For |
406 | * example, transformations that only apply to Dialects can be isolated into a FTVMDialectLegalize |
407 | * string. This pass calls only those transformations that have been registered using the supplied |
408 | * legalize_map_attr_name. Defaults to "FTVMLegalize". |
409 | * |
410 | * \return The pass. |
411 | */ |
412 | TVM_DLL Pass Legalize(const String& legalize_map_attr_name = "FTVMLegalize" ); |
413 | |
414 | /*! |
415 | * \brief Canonicalize cast expressions to make operator fusion more efficient. |
416 | * |
417 | * \return The pass. |
418 | */ |
419 | TVM_DLL Pass CanonicalizeCast(); |
420 | |
421 | /*! |
422 | * \brief Add abstraction over a constructor or global variable bound to a function. |
423 | * |
424 | * For example: `square` is transformed to |
425 | * `fn (%x: int32) -> int32 { square(x) }`. |
426 | * |
427 | * See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion |
428 | * for more details. |
429 | * |
430 | * \param expand_constructor Whether to expand constructors. |
431 | * \param expand_global_var Whether to expand global variables. |
432 | * |
433 | * \return The pass. |
434 | */ |
435 | TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var); |
436 | |
437 | /*! |
438 | * \brief Partition a Relay program into regions that can be executed on |
439 | * different backends. |
440 | * |
441 | * \return The pass. |
442 | */ |
443 | TVM_DLL Pass PartitionGraph(); |
444 | |
445 | /*! |
446 | * \brief Inline the global functions marked as `inline` in a given Relay |
447 | * IRModule. |
448 | * |
449 | * \return The pass. |
450 | */ |
451 | TVM_DLL Pass Inline(); |
452 | |
453 | /*! |
454 | * \brief Remove the unused functions in the Relay IRModule. |
455 | * |
456 | * \param entry_functions The entry functions used to search for the functions that |
457 | * are being used. |
458 | * |
459 | * \return The pass. |
460 | */ |
461 | TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions); |
462 | |
463 | /*! |
464 | * \brief Simplify the Relay expression. |
465 | * |
466 | * \return The pass. |
467 | */ |
468 | TVM_DLL Pass SimplifyExpr(); |
469 | |
470 | /*! |
471 | * \brief Run any custom passes registered under "RelayToTIR" attributes on TargetKinds. |
472 | * |
473 | * This pass looks for inline, let-bound or global functions which have a "Compiler" attribute. |
474 | * If the attribute value corresponds to a TargetKind with a "RelayToTIR" attribute, then the |
475 | * 'custom' pass bound to that attribute is run (at most once) on the IRModule as a whole. |
476 | * |
477 | * If, in addition, the \p config has a Target with a matching TargetKind, that Target is set |
478 | * as the 'current' target before the custom pass is executed. In this way it is possible |
479 | * for custom passes to pick up target options which may guide how they transform the IRModule. |
480 | * (Those targets are referred to as 'extern codegen targets' elsewhere). |
481 | * |
482 | * A typical custom pass will: |
483 | * - Find calls to "Compiler" attributed functions with matching compiler name. |
484 | * - Lower those functions to TIR PrimFuncs. |
485 | * - Bind those functions into the IRModule under the functions' "global_symbol" attribute. |
486 | * - Replace all calls to those functions with 'call_lowered' to the matching global. |
487 | * Care should be taken to handle multiple calls to the same function. |
488 | * See src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc for an example custom pass. |
489 | * |
490 | * It is also possible (despite the pass and attribute names!) for the custom pass to proceed |
491 | * directly to a runtime::Module, which can be attached to the output IRModule's "external_mods" |
492 | * attribute (taking care not to clobber any existing modules). In this case the flow is as above, |
493 | * except: |
494 | * - The runtime::Module must contain a binding for each compiled function under their |
495 | * "global_symbol" (ie runtime::Module::ImplementsFunction should return true). |
496 | * - A Relay Function must be bound (or re-bound) into the result IRModule, again with the same |
497 | * "global_symbol", but with only the "Extern" attribute set to Integer(1). The function body |
498 | * should be the original function body. In this way we always have a TVM definition matching |
499 | * every global function name. |
500 | * |
501 | * There are many existing runtime::Modules, ranging from source to object to dynamic libraries to |
502 | * entirely custom implementations. Some of those may require additional compilation using |
503 | * 'export_library' on the final build artifact. |
504 | * |
505 | * The OutlineCompilerFunctionsWithExistingGlobalSymbols and MarkCompilerFunctionsAsExtern utility |
506 | * passes can be used by custom passes to take care of some of the boilerplate. |
507 | * |
508 | * TODO(mbs): Rename PreLoweringTargetHooks? |
509 | * |
510 | * \param config All available targets. |
511 | * |
512 | * \return The pass. |
513 | */ |
514 | TVM_DLL Pass RelayToTIRTargetHook(CompilationConfig config); |
515 | |
516 | /*! |
517 | * \brief A pass for manifesting explicit memory allocations and rewriting |
518 | * specific dialects. |
519 | * |
520 | * \param cpu_virtual_device VirtualDevice for computations and data which must reside on a CPU, |
521 | * such as shapes and shape functions. |
522 | * |
523 | * \return The pass. |
524 | */ |
525 | TVM_DLL Pass ManifestAlloc(VirtualDevice cpu_virtual_device); |
526 | |
527 | /*! |
528 | * \brief A pass for manifesting variable lifetimes by inserting kill operations when variables |
529 | * become dead. This pass should be run after ManifestAlloc, and should not be run more than once. |
530 | * |
531 | * \return The pass. |
532 | */ |
533 | TVM_DLL Pass ManifestLifetimes(); |
534 | |
535 | /*! |
536 | * \brief Uses existing "on_device" and "device_copy" CallNodes to infer the \p VirtualDevice on |
537 | * which every Relay sub-expression should run and the result stored. Captures the result of that |
538 | * analysis using new "on_device" and "device_copy" CallNodes. |
539 | * |
540 | * See tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator} |
541 | * for help recovering the device for an arbitrary sub-expression in downstream transformations. |
542 | * |
543 | * \param config Describes the targets and default \p VirtualDevice for all primitive operators and |
544 | * host sub-expressions. |
545 | * |
546 | * \return The pass. |
547 | */ |
548 | TVM_DLL Pass PlanDevices(CompilationConfig config); |
549 | |
550 | /*! |
551 | * \brief This transform flattens atrous convolution, which corresponds to the sequence of |
552 | * operations: "space_to_batch_nd"->"conv2d"->"batch_to_space_nd", and converts them into subgraphs |
553 | * with a convolution with the modified "dilation" and recalculated "padding" parameters. |
554 | * |
555 | * \return The pass. |
556 | */ |
557 | TVM_DLL Pass FlattenAtrousConv(); |
558 | |
559 | /*! |
560 | * \brief Annotates the minimum required memory of each primitive function callsite by analyzing |
561 | * the liveness of the input/output tensors at each function callsite and calculating the total |
562 | * amount of memory these tensors require. This is added as a "used_memory" annotation to the |
563 | * function in question as a list of the number of bytes for each callsite. In addition, the |
564 | * containing function is annotated with an "io_used_memory" annotation which refers to the total |
565 | * memory required for the IO tensors. |
566 | * |
567 | * Note: This pass does not support dynamic shapes, it is the user's responsibility to check this |
568 | * pass isn't applied where dynamic shapes may be input. |
*
* \return The pass.
569 | */ |
570 | TVM_DLL Pass AnnotateUsedMemory(); |
571 | |
572 | /*! |
573 | * \brief Captures the post-dfs index and dominator post-dfs index of (most) expression nodes in |
574 | * their span, in the form "index:<post-dfs index>:<dominator post-dfs index>". This is useful for |
575 | * debugging since a) it helps identify pretty-printed sub-expressions within the overall model |
576 | * and b) the indexes are heavily used by Collage for its compact representation of sub-graphs. |
577 | * |
578 | * Note that Op and Constructor nodes are not changed even though they are assigned a |
579 | * post-dfs index. |
*
* \return The pass.
580 | */ |
581 | TVM_DLL Pass CapturePostDfsIndexInSpans(); |
582 | |
583 | /*! |
584 | * \brief Calls device dependent memory scope analysis pass, collects mapping of desirable |
585 | * expr->memory_scope and annotates expressions by VirtualDevice with required memory_scope. |
*
* \return The pass.
586 | */ |
587 | TVM_DLL Pass AnnotateMemoryScope(); |
588 | |
589 | /*! |
590 | * \brief Removes non-fused reshapes after lowering the graph. |
591 | * InferType() cannot be invoked after calling this pass as it removes reshapes from the call |
592 | * graph. Many targets only need buffer addresses irrespective of the shapes of them. This makes |
593 | * reshapes symbolic once the graph has been lowered. Reshape removal results in smaller code |
594 | * size and reduced buffer allocations. It opens up opportunities of operator fusion in the target |
595 | * backend. Thus, consequently, it improves the performance of the inference. |
*
* \return The pass.
596 | */ |
597 | TVM_DLL Pass RemoveStandaloneReshapes(); |
598 | |
599 | } // namespace transform |
600 | |
601 | /*! |
602 | * \brief Bind the free variables to a Relay expression. This is a helper |
603 | * function usually called by other pass functions to help optimizations. |
604 | * If any free variables are introduced into a function, those are added |
605 | * to the function parameters. |
606 | * Additionally this may change the order of parameters if you map a variable |
607 | * to a variable. |
608 | * |
609 | * \param expr The input expression. |
610 | * \param binds The variable to expression map that will be used to help the |
611 | * binding. |
612 | * |
613 | * \return The updated expression. |
614 | */ |
615 | TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds); |
616 | |
617 | /*! |
618 | * \brief Substitute variables with new variables (including function parameters) in a function. |
619 | * This is a helper function usually called by other pass functions to help optimizations. |
620 | * Expects all values in the bind map to be Vars. |
621 | * |
622 | * \param func The input function. |
623 | * \param binds The variable to expression map that will be used to help the |
624 | * binding. |
625 | * |
626 | * \return The updated function. |
627 | */ |
628 | TVM_DLL Function SubstituteBoundVars(const Function& func, const tvm::Map<Var, Expr>& binds); |
629 | |
630 | /*! |
631 | * \brief Apply rewrite rules to rewrite the expr in post DFS order. This |
632 | * function is used as a helper function to rewrite an expression in a pass. |
633 | * |
634 | * \param expr The expression. |
635 | * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite |
636 | * rule function. |
637 | * \param fcontext Additional callback to provide context argument for each call node. |
638 | * \param fmulti_ref_trigger Transformation function to be called when |
639 | * an Expr is consumed by multiple callers. |
640 | * \return The rewritten expression. |
641 | */ |
642 | TVM_DLL Expr ForwardRewrite(const Expr& expr, const String& rewrite_map_attr_name, |
643 | std::function<ObjectRef(const Call&)> fcontext = nullptr, |
644 | std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr); |
645 | |
646 | /*! |
647 | * \brief Apply rewrite rules to rewrite the expr in post DFS order. This |
648 | * function is used as a helper function to rewrite an expression in a pass. |
649 | * |
650 | * \param expr The expression. |
651 | * \param rewrite_func The rewrite func that will apply to all operators. |
652 | * \param fcontext Additional callback to provide context argument for each call node. |
653 | * \param fmulti_ref_trigger Transformation function to be called when |
654 | * an Expr is consumed by multiple callers. |
655 | * |
656 | * \return The rewritten expression. |
657 | */ |
658 | TVM_DLL Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func, |
659 | std::function<ObjectRef(const Call&)> fcontext = nullptr, |
660 | std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr); |
661 | |
662 | /*! |
663 | * \brief Rewrite the annotated program. |
664 | * |
665 | * \param expr The expression. |
666 | * \param fallback_device The fallback device which is the default device for |
667 | * operators without annotation. |
668 | * |
669 | * \return The updated program. |
670 | */ |
671 | TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device); |
672 | |
673 | /*! |
674 | * \brief Turn an expression into continuation passing style (CPS). |
675 | * |
676 | * CPS means that every function will, instead of returning the result directly, |
677 | * be passed down an extra function (called the continuation) as argument, |
678 | * and pass the result to the continuation instead. |
679 | * |
680 | * Thus, every function call has to be passed an extra argument |
681 | * that represents the rest of the computation (hence the name of continuation). |
682 | * |
683 | * Similarly, all other compute will be wrapped and call the continuation as well. |
684 | * |
685 | * \param f The function. |
686 | * \param mod The module. |
687 | * |
688 | * \return The converted Function. |
689 | */ |
690 | TVM_DLL Function ToCPS(const Function& f, const IRModule& mod); |
691 | |
692 | /*! |
693 | * \brief Remove the continuation argument of a CPS function. |
694 | * |
695 | * Note that this only transforms the type back into un-CPS form |
696 | * when there is no higher order input/output. |
697 | * |
698 | * \param f The function. |
699 | * |
700 | * \return The converted Function. |
701 | */ |
702 | TVM_DLL Function UnCPS(const Function& f); |
703 | |
704 | /*! |
705 | * \brief Deduplicate the bound variables and type variables in the expression. |
706 | * |
707 | * \param e The expression. |
708 | * |
709 | * \return The deduplicated expression. |
710 | */ |
711 | TVM_DLL Expr DeDup(const Expr& e); |
712 | |
713 | namespace legalize { |
/*!
 * \brief Legalize a single expression.
 *
 * NOTE(review): presumably the expression-level counterpart of the transform::Legalize pass
 * declared earlier in this header - confirm against the implementation.
 *
 * \param expr The expression to legalize.
 * \param legalize_map_attr_name Presumably the Op's attr name which corresponds to the legalize
 *        rule function, as for transform::Legalize. Unlike that pass, this overload has no
 *        default value.
 *
 * \return The resulting expression.
 */
714 | TVM_DLL Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name); |
715 | } // namespace legalize |
716 | |
717 | } // namespace relay |
718 | } // namespace tvm |
719 | |
720 | #endif // TVM_RELAY_TRANSFORM_H_ |
721 | |