1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file relay/ir/transform.cc |
22 | * \brief Relay specific transformation passes. |
23 | */ |
24 | #include <dmlc/thread_local.h> |
25 | #include <tvm/node/repr_printer.h> |
26 | #include <tvm/relay/transform.h> |
27 | #include <tvm/runtime/registry.h> |
28 | |
29 | namespace tvm { |
30 | namespace relay { |
31 | namespace transform { |
32 | |
33 | TVM_REGISTER_PASS_CONFIG_OPTION("relay.fallback_device_type" , IntImm); |
34 | |
35 | class FunctionPass; |
36 | |
37 | /*! |
38 | * \brief Function-level passes are used to implement various global |
39 | * optimizations for a given Relay module. It fetches one function at a time |
40 | * from the function list in the module for optimization. |
41 | * |
42 | * Note that the scope of passes at this level is a Relay function. Therefore, |
43 | * we cannot add or delete a function through these passes as they are not aware |
44 | * of the global information. |
45 | */ |
46 | class FunctionPassNode : public PassNode { |
47 | public: |
48 | /* \brief The pass meta data.*/ |
49 | PassInfo pass_info; |
50 | |
51 | /*! \brief The packed pass function sketches the real optimization. For |
52 | * instance, we can implement a pass that works on a Relay function as a |
53 | * `pass_func` and let it run on a given module. The same `pass_func` will |
54 | * then be applied on each function in the module. |
55 | */ |
56 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func; |
57 | |
58 | FunctionPassNode() = default; |
59 | |
60 | void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info" , &pass_info); } |
61 | |
62 | /*! |
63 | * \brief Run a function pass on given pass context. |
64 | * |
65 | * \param mod The module that an optimization pass is applied on. |
66 | * \param mod The context that an optimization pass executes on. |
67 | * |
68 | * \return Return the updated module. |
69 | */ |
70 | IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; |
71 | |
72 | /*! |
73 | * \brief Get the pass information/meta data. |
74 | */ |
75 | PassInfo Info() const override { return pass_info; } |
76 | |
77 | static constexpr const char* _type_key = "relay.FunctionPass" ; |
78 | TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode); |
79 | }; |
80 | |
81 | class FunctionPass : public Pass { |
82 | public: |
83 | /*! |
84 | * \brief The constructor |
85 | * \param pass_func The packed function which implements a pass. |
86 | * \param pass_info The pass info. |
87 | */ |
88 | TVM_DLL FunctionPass( |
89 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func, |
90 | PassInfo pass_info); |
91 | |
92 | TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode); |
93 | }; |
94 | |
95 | FunctionPass::FunctionPass( |
96 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func, |
97 | PassInfo pass_info) { |
98 | auto n = make_object<FunctionPassNode>(); |
99 | n->pass_func = std::move(pass_func); |
100 | n->pass_info = std::move(pass_info); |
101 | data_ = std::move(n); |
102 | } |
103 | |
104 | // Perform Module -> Module optimizations at the Function level. |
105 | IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { |
106 | DiagnosticContext previous = DiagnosticContext::Default(mod); |
107 | |
108 | if (pass_ctx->diag_ctx) { |
109 | DiagnosticContext tmp = pass_ctx->diag_ctx.value(); |
110 | pass_ctx->diag_ctx = previous; |
111 | previous = tmp; |
112 | } else { |
113 | pass_ctx->diag_ctx = previous; |
114 | } |
115 | |
116 | ICHECK(pass_ctx->diag_ctx) |
117 | << "The diagnostic context was set at the top of this block this is a bug." ; |
118 | |
119 | const PassInfo& pass_info = Info(); |
120 | |
121 | ICHECK(mod.defined()); |
122 | |
123 | VLOG_CONTEXT << pass_info->name; |
124 | VLOG(0) << "Executing function pass with opt level: " << pass_info->opt_level; |
125 | VLOG(1) << "Input module:" << std::endl << PrettyPrint(mod); |
126 | |
127 | IRModule updated_mod = mod->ShallowCopy(); |
128 | |
129 | std::vector<std::pair<GlobalVar, Function>> updates; |
130 | for (const auto& kv : mod->functions) { |
131 | // only process optimizable Relay Functions |
132 | if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { |
133 | Function updated_func = pass_func(GetRef<Function>(function_node), updated_mod, pass_ctx); |
134 | updates.push_back({kv.first, std::move(updated_func)}); |
135 | } |
136 | } |
137 | |
138 | for (const auto& pair : updates) { |
139 | updated_mod->Add(pair.first, pair.second, true); |
140 | } |
141 | |
142 | ICHECK(pass_ctx->diag_ctx) |
143 | << "The diagnostic context was set at the top of this block this is a bug." ; |
144 | |
145 | pass_ctx->diag_ctx.value().Render(); |
146 | pass_ctx->diag_ctx = previous; |
147 | |
148 | VLOG(1) << "Output module:" << std::endl << PrettyPrint(updated_mod); |
149 | |
150 | // TODO(@jroesch): move away from eager type checking for performance reasons |
151 | // make issue. |
152 | return transform::InferType()(updated_mod); |
153 | } |
154 | |
155 | Pass CreateFunctionPass( |
156 | const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func, |
157 | int opt_level, String name, tvm::Array<String> required) { |
158 | PassInfo pass_info = PassInfo(opt_level, name, required); |
159 | return FunctionPass(pass_func, pass_info); |
160 | } |
161 | |
162 | TVM_REGISTER_NODE_TYPE(FunctionPassNode); |
163 | |
164 | TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass" ) |
165 | .set_body_typed( |
166 | [](runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func, |
167 | PassInfo pass_info) { return FunctionPass(pass_func, pass_info); }); |
168 | |
169 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
170 | .set_dispatch<FunctionPassNode>([](const ObjectRef& ref, ReprPrinter* p) { |
171 | auto* node = static_cast<const FunctionPassNode*>(ref.get()); |
172 | const PassInfo info = node->Info(); |
173 | p->stream << "Run Function pass: " << info->name << " at the optimization level " |
174 | << info->opt_level; |
175 | }); |
176 | |
177 | } // namespace transform |
178 | } // namespace relay |
179 | } // namespace tvm |
180 | |