1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/relay/transform.h
22 * \brief Relay specific transformation passes.
23 */
24#ifndef TVM_RELAY_TRANSFORM_H_
25#define TVM_RELAY_TRANSFORM_H_
26
27#include <tvm/ir/transform.h>
28#include <tvm/relay/attrs/transform.h>
29#include <tvm/relay/expr.h>
30#include <tvm/relay/function.h>
31#include <tvm/relay/op.h>
32#include <tvm/relay/op_attr_types.h>
33#include <tvm/target/compilation_config.h>
34#include <tvm/target/target.h>
35#include <tvm/target/virtual_device.h>
36
37#include <string>
38
39namespace tvm {
40namespace relay {
41namespace transform {
42
43using Pass = tvm::transform::Pass;
44using PassNode = tvm::transform::PassNode;
45using PassInfo = tvm::transform::PassInfo;
46using PassInfoNode = tvm::transform::PassInfoNode;
47using PassContext = tvm::transform::PassContext;
48using PassContextNode = tvm::transform::PassContextNode;
49using Sequential = tvm::transform::Sequential;
50
51/*!
52 * \brief Create a function pass.
53 *
54 * \param pass_func The packed function that contains the optimization.
55 * \param opt_level The optimization level of the function pass.
56 * \param name The name of the function pass.
57 * \param required The list of the passes that the function pass is dependent on.
58 *
59 * \return The created function pass.
60 */
61TVM_DLL Pass CreateFunctionPass(
62 const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
63 int opt_level, String name, tvm::Array<String> required);
64
65/*! \brief Remove let-bound expressions which do not affect the program result.
66 *
67 * This pass will remove let bindings which are not referenced. If inline_once is True,
68 * let bindings which are only referenced once will also be inlined.
69 *
70 * For example, this pass should turn `let a = 1; 2` into `2`,
71 * as the value of the expression does not depend on a.
72 *
73 * As another example, `let a = 1; a` will be optimized into 1 if inline_once is True.
74 *
75 * If ignore_purity is False, possibly side-effecting expressions (such as memory allocation,
76 * random number generation, reading/writing references, or calls to primitive or external
77 * functions) are never elided or inlined. This is sound, but ignore_purity can be set to True
78 * to suppress this check.
79 *
80 * The analysis is fairly conservative, for example it assumes all local functions
81 * may be called more than once, any functions passed as arguments have side effects,
82 * and so on.
83 *
84 * \param inline_once Whether or not to inline bindings used exactly once.
85 * \param ignore_purity Whether to ignore whether expressions have side-effects.
86 *
87 * \return The pass.
88 */
89TVM_DLL Pass DeadCodeElimination(bool inline_once = false, bool ignore_purity = false);
90
91/*!
92 * \brief Convert all expressions of TensorType into GradCell,
93 * an algebraic data type defined in gradient.rly.
94 *
95 * This will delay or decrease memory usage. All calls to
96 * ones, ones_like, zeros, zeros_like will not immediately instantiate a tensor in memory,
97 * rather only instantiate if needed. It also defines + and * operation
98 * between GradCell types which can increase performance when using
99 * zero-filled or one-filled tensors, which is the case in reverse mode ad.
100 *
101 * \return the pass
102 */
103TVM_DLL Pass LazyGradientInit();
104
105/*!
106 * \brief Fold constant expressions.
107 *
108 * For backward compatibility reasons, QNN primitives are skipped from folding by default.
109 * Some transformation passes, such as FakeQuantizationToInteger, require QNN primitives to be
110 * kept in constant subgraphs. Uncontrolled constant folding of QNN primitives may break
111 * applicability of FakeQuantizationToInteger. We suggest using the FoldConstant pass with the
112 * non-default fold_qnn=True value only when all other QNN sensitive passes were already applied.
113 *
114 * \param fold_qnn Whether to fold constants for QNN operations.
115 *
116 * \return The pass.
117 */
118TVM_DLL Pass FoldConstant(bool fold_qnn = false);
119
120/*!
121 * \brief Split function with huge number of arguments to smaller pieces.
122 *
 * \param max_function_args Presumably the maximum number of arguments a function may have
 *        before it is split (inferred from the parameter name; TODO confirm against the
 *        implementation).
 *
123 * \return The pass.
124 */
125TVM_DLL Pass SplitArgs(int max_function_args);
126
127/*!
128 * \brief Fuse operations in an expr into separate functions.
129 *
130 * \param fuse_opt_level Optimization level. If it is -1 it will be inferred from pass context.
131 *
132 * \return The pass.
133 */
134TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
135
136/*!
137 * \brief The inverse operation of FuseOps. It transforms a fused program returned by
138 * FuseOps into the program before FuseOps. (i.e. x == DefuseOps(FuseOps(x)))
139 *
140 * \return The pass.
141 */
142TVM_DLL Pass DefuseOps();
143
144/*!
145 * \brief Rewrite the annotated program.
146 *
147 * \param fallback_device The fallback device which is the default device for
148 * operators without annotation.
149 *
150 * \return The pass.
151 */
152TVM_DLL Pass RewriteAnnotatedOps(int fallback_device);
153
154/*!
155 * \brief Turn an expression to Basic Block Normal Form.
156 *
157 * We define a block as a group of expressions implied by the scope structure.
158 *
159 * Each graph node can only belong to a single block.
160 *
161 * For any value that is being used in multiple blocks, it has to be referred
162 * by a Var which is defined in a block, whose scope is the least common ancestor
163 * of blocks this value is used.
164 *
165 * \return The pass.
166 */
167TVM_DLL Pass ToBasicBlockNormalForm();
168
169/*!
170 * \brief Turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
171 *
172 * It will turn an expression that is in a graph form (with sharing implicit),
173 * to an expression with explicit sharing (A-Normal Form).
174 *
175 * The scope of the root expression is the global scope.
176 *
177 * The scope of any non root expression is the least common ancestor of all its scopes.
178 *
179 * Values are ordered by post-DFS order in each scope.
180 *
181 * \return The pass.
182 */
183TVM_DLL Pass ToANormalForm();
184
185/*!
186 * \brief ToANormalForm but on incomplete graph.
187 *
188 * \param expr the graph.
189 *
190 * \return The transformed program.
191 */
192TVM_DLL Expr ToANormalForm(const Expr& expr);
193
194/*!
195 * \brief Turn an expression into continuation passing style (CPS).
196 *
197 * CPS means that every function will, instead of returning the result directly,
198 * be passed down an extra function (called the continuation) as argument,
199 * and pass the result to the continuation instead.
200 *
201 * Thus, every function call has to be passed an extra argument
202 * that represents the rest of the computation (hence the name continuation).
203 *
204 * Similarly, all other compute will be wrapped and call the continuation as well.
205 *
206 * \return The pass.
207 */
208TVM_DLL Pass ToCPS();
209
210/*!
211 * \brief Remove let binding and directly share via pointer instead.
212 *
213 * It will remove all let bindings,
214 * and turn all of the variables bound by let into direct pointer references.
215 *
216 * \return The pass.
217 */
218TVM_DLL Pass ToGraphNormalForm();
219
220/*!
221 * \brief Aggressive constant propagation/constant folding/inlining.
222 *
223 * It will do as much computation in compile time as possible.
224 * It has two benefits: remove runtime overhead, and allow more optimization (typically fusion).
225 * As a side effect, code size will explode.
226 *
227 * \return The pass.
228 */
229TVM_DLL Pass PartialEval();
230
231/*!
232 * \brief Simplify certain operators during inference. For example, the result
233 * of a batch norm which is indexed at tuple index 0 will be unpacked into a
234 * number of simplified operators.
235 *
236 * \return The Pass.
237 */
238TVM_DLL Pass SimplifyInference();
239
240/*!
241 * \brief Replaces non linear activation functions with their fast but approximate counterparts.
242 *
243 * \return The Pass.
244 */
245TVM_DLL Pass FastMath();
246
247/*!
248 * \brief Find Dynamic ops and make them static
249 *
250 * Searches the graph for dynamic ops. If the dynamic inputs to those ops are constants, it replaces
251 * them with static ops and re-performs type inference and constant folding. The pass repeats
252 * itself until the graph stops changing or we run too many iterations.
253 *
254 * \return The pass.
255 */
256TVM_DLL Pass DynamicToStatic();
257
258/*!
259 * \brief Infer the type of an expression.
260 *
261 * The result of type checking is a new expression with unambiguous
262 * type information filled in, as well as its checked type field
263 * populated with the result type.
264 *
265 * \return The pass.
266 */
267TVM_DLL Pass InferType();
268
269/*!
270 * \brief Infer the type of an expression, reusing existing type information.
271 *
272 * The result of type checking is a new expression with unambiguous
273 * type information filled in for the given node only. The local
274 * version can use existing type information populated throughout
275 * the expression and assumes this information is correct. The local
276 * version also avoids examining large amounts of the graph assuming
277 * type information is filled in properly which makes it much faster if we
278 * iteratively call type inference.
279 *
280 * \return The type of the expression.
281 */
282TVM_DLL Type InferTypeLocal(const Expr& expr);
283
284/*!
285 * \brief Search and eliminate common subexpression. For example, if there are
286 * two expressions evaluated to an identical value, a single variable is created
287 * and these two expressions are replaced by this variable.
288 *
289 * \param fskip The callback argument that allows to skip certain expressions.
290 *
291 * \return The pass.
292 */
293TVM_DLL Pass EliminateCommonSubexpr(runtime::PackedFunc fskip = nullptr);
294
295/*!
296 * \brief Combine parallel 2d convolutions into a single convolution if the
297 * number of branches of this conv2d operator is not less than
298 * `min_num_branches`.
299 *
300 * \param min_num_branches The minimum number of branches.
301 *
302 * \return The pass.
303 */
304TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
305
306/*!
307 * \brief Combine parallel dense ops into a single batch_matmul if the
308 * number of branches of this dense operator is not less than
309 * `min_num_branches`.
310 *
311 * \param min_num_branches The minimum number of branches.
312 * \param to_batch_matmul Whether to combine parallel dense ops to batch matmul.
313 * If set false, combine dense ops to single dense op.
314 *
315 * \return The pass.
316 */
317TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3, bool to_batch_matmul = true);
318
319/*!
320 * \brief Combine parallel batch_matmul ops into a single batch_matmul
321 * if the number of branches of this batch_matmul operator is not less than
322 * `min_num_branches`.
323 *
324 * \param min_num_branches The minimum number of branches.
325 *
326 * \return The pass.
327 */
328TVM_DLL Pass CombineParallelBatchMatmul(uint64_t min_num_branches = 3);
329
330/*!
331 * \brief Backward fold axis scaling into weights of conv/dense operators.
332 *
333 * \return The pass.
334 */
335TVM_DLL Pass BackwardFoldScaleAxis();
336
337/*!
338 * \brief Forward fold axis scaling into weights of conv/dense operators.
339 *
340 * \return The pass.
341 */
342TVM_DLL Pass ForwardFoldScaleAxis();
343
344/*!
345 * \brief A sequential pass that executes ForwardFoldScaleAxis and
346 * BackwardFoldScaleAxis passes.
347 *
348 * \return The pass.
349 */
350TVM_DLL Pass FoldScaleAxis();
351
352/*!
353 * \brief Canonicalize some operators to the simplified operators. For example,
354 * bias_add can be canonicalized to expand_dims and broadcast_add.
355 *
356 * \return The pass.
357 */
358TVM_DLL Pass CanonicalizeOps();
359
360/*!
361 * \brief Alternate the layouts of operators or replace primitive operators
362 * with other expressions.
363 *
364 * \return The pass.
365 */
366TVM_DLL Pass AlterOpLayout();
367
368/*!
369 * \brief Do layout rewrite according to the tile structure created by auto-scheduler.
370 * \return The pass
371 */
372TVM_DLL Pass AutoSchedulerLayoutRewrite();
373
374/*!
375 * \brief Do layout rewrite according to the tile structure created by meta-schedule.
376 * \return The pass
377 */
378TVM_DLL Pass MetaScheduleLayoutRewrite();
379
380/*!
381 * \brief Given a dest layout, this pass transforms the expr such that most of the ops input data
382 * layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one
383 * at the start and one at the end.
384 *
385 * This pass is not a part of relay.build and is expected to be called between framework-relay
386 * parser and relay.build call. This is very helpful for hardware backends that support/prefer only
387 * one type of data layout.
388 *
389 * RFC - https://discuss.tvm.ai/t/layout-conversion-pass/4009
390 *
391 * This pass uses most of the AlterOpLayout and InferCorrectLayout infrastructure. We can define new
392 * layouts for conv2d ops for now. Most of the other operators try to adapt to their input layout
393 * using the InferCorrectLayout infrastructure.
394 *
395 * \param desired_layouts Specify mapping of op_name to array of desired layouts for each input.
396 * For example: Map("nn.conv2d", Array("NHWC", "OHWI")),
397 * this specifies the desired layout for data then kernel for nn.conv2d.
398 * \return The pass.
399 */
400TVM_DLL Pass ConvertLayout(const Map<String, Array<String>>& desired_layouts);
401
402/*!
403 * \brief Legalizes an expr with another expression.
404 * \param legalize_map_attr_name The Op's attr name which corresponds to the legalize rule function.
405 * One can collect and isolate similar type of legalize transformations using this param. For
406 * example, transformations that only apply to Dialects can be isolated into a FTVMDialectLegalize
407 * string. This pass calls only those transformations that have been registered using the supplied
408 * legalize_map_attr_name.
409 *
410 * \return The pass.
411 */
412TVM_DLL Pass Legalize(const String& legalize_map_attr_name = "FTVMLegalize");
413
414/*!
415 * \brief Canonicalize cast expressions to make operator fusion more efficient.
416 *
417 * \return The pass.
418 */
419TVM_DLL Pass CanonicalizeCast();
420
421/*!
422 * \brief Add abstraction over a constructor or global variable bound to a function.
423 *
424 * For example: `square` is transformed to
425 * `fn (%x: int32) -> int32 { square(x) }`.
426 *
427 * See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion
428 * for more details.
429 *
430 * \param expand_constructor Whether to expand constructors.
431 * \param expand_global_var Whether to expand global variables.
432 *
433 * \return The pass.
434 */
435TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);
436
437/*!
438 * \brief Partition a Relay program into regions that can be executed on
439 * different backends.
440 *
441 * \return The pass.
442 */
443TVM_DLL Pass PartitionGraph();
444
445/*!
446 * \brief Inline the global functions marked as `inline` in a given Relay
447 * IRModule.
448 *
449 * \return The pass.
450 */
451TVM_DLL Pass Inline();
452
453/*!
454 * \brief Remove the unused functions in the Relay IRModule.
455 *
456 * \param entry_functions The entry functions used to search the functions that
457 * are being used.
458 *
459 * \return The pass.
460 */
461TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
462
463/*!
464 * \brief Simplify the Relay expression.
465 *
466 * \return The pass.
467 */
468TVM_DLL Pass SimplifyExpr();
469
470/*!
471 * \brief Run any custom passes registered under "RelayToTIR" attributes on TargetKinds.
472 *
473 * This pass looks for inline, let-bound or global functions which have a "Compiler" attribute.
474 * If the attribute value corresponds to a TargetKind with a "RelayToTIR" attribute, then the
475 * 'custom' pass bound to that attribute is run (at most once) on the IRModule as a whole.
476 *
477 * If, in addition, the \p config has a Target with a matching TargetKind, that Target is set
478 * as the 'current' target before the custom pass is executed. In this way it is possible
479 * for custom passes to pick up target options which may guide how they transform the IRModule.
480 * (Those targets are referred to as 'extern codegen targets' elsewhere).
481 *
482 * A typical custom pass will:
483 * - Find calls to "Compiler" attributed functions with matching compiler name.
484 * - Lower those functions to TIR PrimFuncs.
485 * - Bind those functions into the IRModule under the functions' "global_symbol" attribute.
486 * - Replace all calls to those functions with 'call_lowered' to the matching global.
487 * Care should be taken to handle multiple calls to the same function.
488 * See src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc for an example custom pass.
489 *
490 * It is also possible (despite the pass and attribute names!) for the custom pass to proceed
491 * directly to a runtime::Module, which can be attached to the output IRModule's "external_mods"
492 * attribute (taking care not to clobber any existing modules). In this case the flow is as above,
493 * except:
494 * - The runtime::Module must contain a binding for each compiled function under their
495 * "global_symbol" (ie runtime::Module::ImplementsFunction should return true).
496 * - A Relay Function must be bound (or re-bound) into the result IRModule, again with the same
497 * "global_symbol", but with only the "Extern" attribute set to Integer(1). The function body
498 * should be the original function body. In this way we always have a TVM definition matching
499 * every global function name.
500 *
501 * There are many existing runtime::Modules, ranging from source to object to dynamic libraries to
502 * entirely custom implementations. Some of those may require additional compilation using
503 * 'export_library' on the final build artifact.
504 *
505 * The OutlineCompilerFunctionsWithExistingGlobalSymbols and MarkCompilerFunctionsAsExtern utility
506 * passes can be used by custom passes to take care of some of the boilerplate.
507 *
508 * TODO(mbs): Rename PreLoweringTargetHooks?
509 *
510 * \param config All available targets.
511 *
512 * \return The pass.
513 */
514TVM_DLL Pass RelayToTIRTargetHook(CompilationConfig config);
515
516/*!
517 * \brief A pass for manifesting explicit memory allocations and rewriting
518 * specific dialects.
519 *
520 * \param cpu_virtual_device VirtualDevice for computations and data which must reside on a CPU,
521 * such as shapes and shape functions.
522 *
523 * \return The pass.
524 */
525TVM_DLL Pass ManifestAlloc(VirtualDevice cpu_virtual_device);
526
527/*!
528 * \brief A pass for manifesting variable lifetimes by inserting kill operations when variables
529 * become dead. This pass should be run after ManifestAlloc, and should not be run more than once.
530 *
531 * \return The pass.
532 */
533TVM_DLL Pass ManifestLifetimes();
534
535/*!
536 * \brief Uses existing "on_device" and "device_copy" CallNodes to infer the \p VirtualDevice on
537 * which every Relay sub-expression should run and the result stored. Captures the result of that
538 * analysis using new "on_device" and "device_copy" CallNodes.
539 *
540 * See tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator}
541 * for help recovering the device for an arbitrary sub-expression in downstream transformations.
542 *
543 * \param config Describes the targets and default \p VirtualDevice for all primitive operators and
544 * host sub-expressions.
545 *
546 * \return The pass.
547 */
548TVM_DLL Pass PlanDevices(CompilationConfig config);
549
550/*!
551 * \brief This transform flattens atrous convolution, which corresponds to the sequence of
552 * operations: "space_to_batch_nd"->"conv2d"->"batch_to_space_nd" and convert them into subgraphs
553 * with a convolution with the modified "dilation" and recalculated "padding" parameters.
554 *
555 * \return The pass.
556 */
557TVM_DLL Pass FlattenAtrousConv();
558
559/*!
560 * \brief Annotates the minimum required memory of each primitive function callsite by analyzing
561 * the liveness of the input/output tensors at each function callsite and calculating the total
562 * amount of memory these tensors require. This is added as a "used_memory" annotation to the
563 * function in question as a list of the number of bytes for each callsite. In addition, the
564 * containing function is annotated with an "io_used_memory" annotation which refers to the total
565 * memory required for the IO tensors.
566 *
567 * Note: This pass does not support dynamic shapes, it is the user's responsibility to check this
568 * pass isn't applied where dynamic shapes may be input.
 *
 * \return The pass.
569 */
570TVM_DLL Pass AnnotateUsedMemory();
571
572/*!
573 * \brief Captures the post-dfs index and dominator post-dfs index of (most) expression nodes in
574 * their span, in the form "index:<post-dfs index>:<dominator post-dfs index>". This is useful for
575 * debugging since a) it helps identify pretty-printed sub-expressions within the overall model
576 * and b) the indexes are heavily used by Collage for its compact representation of sub-graphs.
577 *
578 * Note that Op and Constructor nodes are not changed even though they are assigned a
579 * post-dfs index.
 *
 * \return The pass.
580 */
581TVM_DLL Pass CapturePostDfsIndexInSpans();
582
583/*!
584 * \brief Calls device dependent memory scope analysis pass, collects mapping of desirable
585 * expr->memory_scope and annotates expressions by VirtualDevice with required memory_scope.
 *
 * \return The pass.
586 */
587TVM_DLL Pass AnnotateMemoryScope();
588
589/*!
590 * \brief Removes non-fused reshapes after lowering the graph.
591 * InferType() cannot be invoked after calling this pass as it removes reshapes from the call
592 * graph. Many targets only need buffer addresses irrespective of the shapes of them. This makes
593 * reshapes symbolic once the graph has been lowered. Reshape removal results in smaller code
594 * size and reduced buffer allocations. It opens up opportunities of operator fusion in the target
595 * backend. Thus, consequently, it improves the performance of the inference.
 *
 * \return The pass.
596 */
597TVM_DLL Pass RemoveStandaloneReshapes();
598
599} // namespace transform
600
601/*!
602 * \brief Bind the free variables to a Relay expression. This is a helper
603 * function usually called by other pass functions to help optimizations.
604 * If any free variables are introduced into a function, those are added
605 * to the function parameters.
606 * Additionally this may change the order of parameters if you map a variable
607 * to a variable.
608 *
609 * \param expr The input expression.
610 * \param binds The variable to expression map that will be used to help the
611 * binding.
612 *
613 * \return The updated expression.
614 */
615TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
616
617/*!
618 * \brief Substitute variables with new variables (including function parameters) in a function.
619 * This is a helper function usually called by other pass functions to help optimizations.
620 * Expects all values in the bind map to be Vars.
621 *
622 * \param func The input function.
623 * \param binds The variable to expression map that will be used to help the
624 * binding.
625 *
626 * \return The updated function.
627 */
628TVM_DLL Function SubstituteBoundVars(const Function& func, const tvm::Map<Var, Expr>& binds);
629
630/*!
631 * \brief Apply rewrite rules to rewrite the expr in post DFS order. This
632 * function is used as a helper function to rewrite an expression in a pass.
633 *
634 * \param expr The expression.
635 * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
636 * rule function.
637 * \param fcontext Additional callback to provide context argument for each call node.
638 * \param fmulti_ref_trigger Transformation function to be called when
639 * an Expr is consumed by multiple callers.
640 * \return The rewritten expression.
641 */
642TVM_DLL Expr ForwardRewrite(const Expr& expr, const String& rewrite_map_attr_name,
643 std::function<ObjectRef(const Call&)> fcontext = nullptr,
644 std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
645
646/*!
647 * \brief Apply rewrite rules to rewrite the expr in post DFS order. This
648 * function is used as a helper function to rewrite an expression in a pass.
649 *
650 * \param expr The expression.
651 * \param rewrite_func The rewrite func that will apply to all operators.
652 * \param fcontext Additional callback to provide context argument for each call node.
653 * \param fmulti_ref_trigger Transformation function to be called when
654 * an Expr is consumed by multiple callers.
655 *
656 * \return The rewritten expression.
657 */
658TVM_DLL Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func,
659 std::function<ObjectRef(const Call&)> fcontext = nullptr,
660 std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
661
662/*!
663 * \brief Rewrite the annotated program.
664 *
665 * \param expr The expression.
666 * \param fallback_device The fallback device which is the default device for
667 * operators without annotation.
668 *
669 * \return The updated program.
670 */
671TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
672
673/*!
674 * \brief Turn an expression into continuation passing style (CPS).
675 *
676 * CPS means that every function will, instead of returning the result directly,
677 * be passed down an extra function (called the continuation) as argument,
678 * and pass the result to the continuation instead.
679 *
680 * Thus, every function call has to be passed an extra argument
681 * that represents the rest of the computation (hence the name continuation).
682 *
683 * Similarly, all other compute will be wrapped and call the continuation as well.
684 *
685 * \param f The function.
686 * \param mod The module.
687 *
688 * \return The converted Function.
689 */
690TVM_DLL Function ToCPS(const Function& f, const IRModule& mod);
691
692/*!
693 * \brief Remove the continuation argument of a CPS function.
694 *
695 * Note that this only transforms the type back into un-CPS form
696 * when there is no higher order input/output.
697 *
698 * \param f The function.
699 *
700 * \return The converted Function.
701 */
702TVM_DLL Function UnCPS(const Function& f);
703
704/*!
705 * \brief Deduplicate the bound variables and type variables in the expression.
706 *
707 * \param e the expression.
708 *
709 * \return the deduplicated expression.
710 */
711TVM_DLL Expr DeDup(const Expr& e);
712
713namespace legalize {
/*!
 * \brief Expression-level legalization entry point.
 *
 * NOTE(review): the implementation is not visible from this header. The descriptions below are
 * inferred from the signature and from the documentation of the transform::Legalize pass
 * declared earlier in this file -- confirm against the implementation.
 *
 * \param expr The expression to legalize.
 * \param legalize_map_attr_name Presumably the Op's attr name which corresponds to the legalize
 *        rule function, as for transform::Legalize.
 *
 * \return The resulting expression.
 */
714TVM_DLL Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name);
715} // namespace legalize
716
717} // namespace relay
718} // namespace tvm
719
720#endif // TVM_RELAY_TRANSFORM_H_
721