1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file relay/ir/transform.cc
22 * \brief Relay specific transformation passes.
23 */
24#include <dmlc/thread_local.h>
25#include <tvm/node/repr_printer.h>
26#include <tvm/relay/transform.h>
27#include <tvm/runtime/registry.h>
28
29namespace tvm {
30namespace relay {
31namespace transform {
32
// Pass-config option "relay.fallback_device_type", whose value is an IntImm.
TVM_REGISTER_PASS_CONFIG_OPTION("relay.fallback_device_type", IntImm);

// Forward declaration; the managed reference class is defined after FunctionPassNode below.
class FunctionPass;
36
37/*!
38 * \brief Function-level passes are used to implement various global
39 * optimizations for a given Relay module. It fetches one function at a time
40 * from the function list in the module for optimization.
41 *
42 * Note that the scope of passes at this level is a Relay function. Therefore,
43 * we cannot add or delete a function through these passes as they are not aware
44 * of the global information.
45 */
46class FunctionPassNode : public PassNode {
47 public:
48 /* \brief The pass meta data.*/
49 PassInfo pass_info;
50
51 /*! \brief The packed pass function sketches the real optimization. For
52 * instance, we can implement a pass that works on a Relay function as a
53 * `pass_func` and let it run on a given module. The same `pass_func` will
54 * then be applied on each function in the module.
55 */
56 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func;
57
58 FunctionPassNode() = default;
59
60 void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); }
61
62 /*!
63 * \brief Run a function pass on given pass context.
64 *
65 * \param mod The module that an optimization pass is applied on.
66 * \param mod The context that an optimization pass executes on.
67 *
68 * \return Return the updated module.
69 */
70 IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
71
72 /*!
73 * \brief Get the pass information/meta data.
74 */
75 PassInfo Info() const override { return pass_info; }
76
77 static constexpr const char* _type_key = "relay.FunctionPass";
78 TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode);
79};
80
81class FunctionPass : public Pass {
82 public:
83 /*!
84 * \brief The constructor
85 * \param pass_func The packed function which implements a pass.
86 * \param pass_info The pass info.
87 */
88 TVM_DLL FunctionPass(
89 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
90 PassInfo pass_info);
91
92 TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode);
93};
94
95FunctionPass::FunctionPass(
96 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
97 PassInfo pass_info) {
98 auto n = make_object<FunctionPassNode>();
99 n->pass_func = std::move(pass_func);
100 n->pass_info = std::move(pass_info);
101 data_ = std::move(n);
102}
103
104// Perform Module -> Module optimizations at the Function level.
105IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
106 DiagnosticContext previous = DiagnosticContext::Default(mod);
107
108 if (pass_ctx->diag_ctx) {
109 DiagnosticContext tmp = pass_ctx->diag_ctx.value();
110 pass_ctx->diag_ctx = previous;
111 previous = tmp;
112 } else {
113 pass_ctx->diag_ctx = previous;
114 }
115
116 ICHECK(pass_ctx->diag_ctx)
117 << "The diagnostic context was set at the top of this block this is a bug.";
118
119 const PassInfo& pass_info = Info();
120
121 ICHECK(mod.defined());
122
123 VLOG_CONTEXT << pass_info->name;
124 VLOG(0) << "Executing function pass with opt level: " << pass_info->opt_level;
125 VLOG(1) << "Input module:" << std::endl << PrettyPrint(mod);
126
127 IRModule updated_mod = mod->ShallowCopy();
128
129 std::vector<std::pair<GlobalVar, Function>> updates;
130 for (const auto& kv : mod->functions) {
131 // only process optimizable Relay Functions
132 if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
133 Function updated_func = pass_func(GetRef<Function>(function_node), updated_mod, pass_ctx);
134 updates.push_back({kv.first, std::move(updated_func)});
135 }
136 }
137
138 for (const auto& pair : updates) {
139 updated_mod->Add(pair.first, pair.second, true);
140 }
141
142 ICHECK(pass_ctx->diag_ctx)
143 << "The diagnostic context was set at the top of this block this is a bug.";
144
145 pass_ctx->diag_ctx.value().Render();
146 pass_ctx->diag_ctx = previous;
147
148 VLOG(1) << "Output module:" << std::endl << PrettyPrint(updated_mod);
149
150 // TODO(@jroesch): move away from eager type checking for performance reasons
151 // make issue.
152 return transform::InferType()(updated_mod);
153}
154
155Pass CreateFunctionPass(
156 const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
157 int opt_level, String name, tvm::Array<String> required) {
158 PassInfo pass_info = PassInfo(opt_level, name, required);
159 return FunctionPass(pass_func, pass_info);
160}
161
// Register FunctionPassNode as a node type (its _type_key is "relay.FunctionPass").
TVM_REGISTER_NODE_TYPE(FunctionPassNode);
163
164TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass")
165 .set_body_typed(
166 [](runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
167 PassInfo pass_info) { return FunctionPass(pass_func, pass_info); });
168
169TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
170 .set_dispatch<FunctionPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
171 auto* node = static_cast<const FunctionPassNode*>(ref.get());
172 const PassInfo info = node->Info();
173 p->stream << "Run Function pass: " << info->name << " at the optimization level "
174 << info->opt_level;
175 });
176
177} // namespace transform
178} // namespace relay
179} // namespace tvm
180