1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/transforms/compiler_function_utils.cc |
22 | * \brief Helper passes for working with functions with the "Compiler" attribute. |
23 | */ |
24 | |
25 | #include "./compiler_function_utils.h" |
26 | |
27 | #include "tvm/relay/analysis.h" |
28 | #include "tvm/relay/expr_functor.h" |
29 | #include "tvm/relay/transform.h" |
30 | |
31 | namespace tvm { |
32 | namespace relay { |
33 | namespace transform { |
34 | namespace { |
35 | |
36 | /*! |
37 | * \brief Returns the \p FunctionNode of if \p expr if it is a "Compiler" function which should |
38 | * be processed by a pass using \p compiler_filter. Otherwise returns null. |
39 | */ |
40 | const FunctionNode* AsFunctionNode(const Expr& expr, const std::string& compiler_filter) { |
41 | if (const auto* function_node = expr.as<FunctionNode>()) { |
42 | Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler); |
43 | if (opt_compiler.defined() && |
44 | (compiler_filter.empty() || opt_compiler.value() == compiler_filter)) { |
45 | return function_node; |
46 | } |
47 | } |
48 | return nullptr; |
49 | } |
50 | |
51 | /*! |
52 | * \brief Rewrite calls to inlined and let-bound "Compiler" functions to global functions. The given |
53 | * module will be extended with the newly outlined functions. |
54 | */ |
55 | class Outliner : public MixedModeMutator { |
56 | public: |
57 | using MixedModeMutator::VisitExpr_; |
58 | |
59 | Outliner(GlobalSymbolCache* cache, std::string compiler_filter, IRModule mod) |
60 | : cache_(cache), compiler_filter_(std::move(compiler_filter)), mod_(std::move(mod)) {} |
61 | |
62 | Expr VisitExpr_(const LetNode* op) final { |
63 | auto pre_visit = [this](const LetNode* op) { |
64 | Expr var = this->VisitExpr(op->var); |
65 | Expr value = this->VisitExpr(op->value); |
66 | |
67 | if (AsFunctionNode(value, compiler_filter_)) { |
68 | // Inline on-the-fly if the let-bound value is a function of interest. |
69 | this->memo_[var] = value; |
70 | } |
71 | }; |
72 | auto post_visit = [this](const LetNode* op) { |
73 | // Rely on the Memoizer to cache pre-visit values |
74 | Expr value = this->VisitExpr(op->value); |
75 | Expr body = this->VisitExpr(op->body); |
76 | auto expr = GetRef<Expr>(op); |
77 | |
78 | if (AsFunctionNode(value, compiler_filter_)) { |
79 | // The let binding is no longer needed since inlined on-the-fly above. |
80 | this->memo_[expr] = this->VisitExpr(op->body); |
81 | } else { |
82 | Var var = Downcast<Var>(this->VisitExpr(op->var)); |
83 | if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { |
84 | this->memo_[expr] = expr; |
85 | } else { |
86 | this->memo_[expr] = Let(var, value, body); |
87 | } |
88 | } |
89 | }; |
90 | ExpandANormalForm(op, pre_visit, post_visit); |
91 | return memo_[GetRef<Expr>(op)]; |
92 | } |
93 | |
94 | Expr Rewrite_(const CallNode* pre, const Expr& post) final { |
95 | Call new_call = Downcast<Call>(post); |
96 | if (const auto* function_node = AsFunctionNode(new_call->op, compiler_filter_)) { |
97 | auto function = GetRef<Function>(function_node); |
98 | DCHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler |
99 | << "' attribute should not have free variables" ; |
100 | // Ask the cache to supply a unique global var for this function. |
101 | GlobalVar global_symbol = cache_->GetGlobalSymbol(function); |
102 | // Depending on the cache's implementation, two structurally equal (but not object |
103 | // equal) functions may be assigned the same global symbol. If so we'll lift it just |
104 | // once, but rewrite all the calls. |
105 | if (!mod_->ContainGlobalVar(global_symbol->name_hint)) { |
106 | function = |
107 | WithAttr(std::move(function), tvm::attr::kGlobalSymbol, global_symbol->name_hint); |
108 | mod_->Add(global_symbol, function); |
109 | } |
110 | // Update the call. |
111 | return WithFields(new_call, global_symbol); |
112 | } |
113 | return post; |
114 | } |
115 | |
116 | private: |
117 | /*! |
118 | * \brief A cached mapping from functions to global variables. Depending on the implementation |
119 | * the cache may generate fresh symbols or require the function to already have a |
120 | * "global_symbol" attribute, and may share symbols between structurally equal functions. |
121 | */ |
122 | GlobalSymbolCache* cache_; |
123 | /*! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */ |
124 | std::string compiler_filter_; |
125 | /*! \brief Module being rewritten. */ |
126 | IRModule mod_; |
127 | }; |
128 | |
129 | /*! |
130 | * \brief Inline immediate calls to "Composite" functions. |
131 | */ |
132 | class InnerInliner : public MixedModeMutator { |
133 | public: |
134 | InnerInliner() = default; |
135 | |
136 | private: |
137 | using MixedModeMutator::Rewrite_; |
138 | |
139 | Expr Rewrite_(const CallNode* pre, const Expr& post) final { |
140 | Call new_call = Downcast<Call>(post); |
141 | if (const auto* function_node = new_call->op.as<FunctionNode>()) { |
142 | ICHECK(function_node->GetAttr<String>(attr::kComposite).defined()); |
143 | ICHECK_EQ(function_node->params.size(), new_call->args.size()); |
144 | Map<Var, Expr> subst; |
145 | for (size_t i = 0; i < new_call->args.size(); ++i) { |
146 | subst.Set(function_node->params[i], new_call->args[i]); |
147 | } |
148 | return Bind(function_node->body, subst); |
149 | } |
150 | return post; |
151 | } |
152 | }; |
153 | |
154 | /*! |
155 | * \brief Inline calls to global "Compiler" functions with global var in \p global_vars. |
156 | * Both the 'outer' "Compiler" function and any 'inner' "Composite" functions in its body |
157 | * are inlined. |
158 | */ |
159 | class OuterInliner : public MixedModeMutator { |
160 | public: |
161 | OuterInliner(IRModule mod, Array<GlobalVar> global_vars_) |
162 | : mod_(std::move(mod)), global_vars_(std::move(global_vars_)) {} |
163 | |
164 | private: |
165 | using MixedModeMutator::Rewrite_; |
166 | |
167 | Expr Rewrite_(const CallNode* pre, const Expr& post) final { |
168 | Call new_call = Downcast<Call>(post); |
169 | if (const auto* global_var_node = new_call->op.as<GlobalVarNode>()) { |
170 | auto global_var = GetRef<GlobalVar>(global_var_node); |
171 | if (std::find(global_vars_.begin(), global_vars_.end(), global_var) != global_vars_.end()) { |
172 | BaseFunc base_func = mod_->Lookup(global_var); |
173 | const auto* function_node = base_func.as<FunctionNode>(); |
174 | ICHECK(function_node); |
175 | ICHECK(function_node->GetAttr<String>(attr::kCompiler).defined()); |
176 | ICHECK_EQ(function_node->params.size(), new_call->args.size()); |
177 | Map<Var, Expr> subst; |
178 | for (size_t i = 0; i < new_call->args.size(); ++i) { |
179 | subst.Set(function_node->params[i], new_call->args[i]); |
180 | } |
181 | Expr new_body = InnerInliner().VisitExpr(function_node->body); |
182 | return Bind(new_body, subst); |
183 | } |
184 | } |
185 | return post; |
186 | } |
187 | |
188 | private: |
189 | /*! \brief Original module we are processing. */ |
190 | IRModule mod_; |
191 | /*! \brief Global vars of functions to inline. */ |
192 | Array<GlobalVar> global_vars_; |
193 | }; |
194 | |
195 | } // namespace |
196 | |
197 | GlobalSymbolCache::~GlobalSymbolCache() = default; |
198 | |
199 | GlobalVar ExistingGlobalSymbolCache::GetGlobalSymbol(const Function& function) { |
200 | Optional<String> opt_global_symbol = function->GetAttr<String>(tvm::attr::kGlobalSymbol); |
201 | ICHECK(opt_global_symbol.defined()) |
202 | << "ExistingGlobalSymbolCache requires all functions to already have a '" |
203 | << tvm::attr::kGlobalSymbol << "' attribute" ; |
204 | std::string global_symbol = opt_global_symbol.value(); |
205 | auto itr = global_vars_.find(global_symbol); |
206 | if (itr != global_vars_.end()) { |
207 | return itr->second; |
208 | } |
209 | // Ok if function does not have a checked_type, but if it does capture it in the global var. |
210 | GlobalVar global_var(global_symbol, function->checked_type_, function->span); |
211 | global_vars_.emplace(global_symbol, global_var); |
212 | return global_var; |
213 | } |
214 | |
215 | tvm::transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cache, |
216 | std::string compiler_filter) { |
217 | runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func = |
218 | [cache = std::move(cache), compiler_filter = std::move(compiler_filter)]( |
219 | IRModule mod, transform::PassContext ctx) { |
220 | VLOG(1) << "OutlineCompilerFunctions input:" << std::endl << PrettyPrint(mod); |
221 | IRModule output_mod = mod->ShallowCopy(); |
222 | for (const auto& kv : mod->functions) { |
223 | if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { |
224 | Expr new_body = |
225 | Outliner(cache.get(), compiler_filter, output_mod).VisitExpr(function_node->body); |
226 | Function new_function = |
227 | WithFields(GetRef<Function>(function_node), /*opt_params=*/{}, new_body); |
228 | output_mod->Add(kv.first, new_function); |
229 | } |
230 | } |
231 | VLOG(1) << "OutlineCompilerFunctions result:" << std::endl << PrettyPrint(output_mod); |
232 | return output_mod; |
233 | }; |
234 | |
235 | return tvm::transform::CreateModulePass(pass_func, 0, "OutlineCompilerFunctions" , {}); |
236 | } |
237 | |
238 | // Any Java programmers in the house? |
239 | tvm::transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols( |
240 | std::string compiler_filter) { |
241 | return OutlineCompilerFunctions(std::make_shared<ExistingGlobalSymbolCache>(), |
242 | std::move(compiler_filter)); |
243 | } |
244 | |
245 | tvm::transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) { |
246 | runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func = |
247 | [compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) { |
248 | VLOG(1) << "MarkCompilerFunctionsAsExtern input:" << std::endl << PrettyPrint(mod); |
249 | IRModule output_mod = mod->ShallowCopy(); |
250 | for (const auto& kv : mod->functions) { |
251 | if (const auto* function_node = AsFunctionNode(kv.second, compiler_filter)) { |
252 | auto new_function = |
253 | WithFields(GetRef<Function>(function_node), function_node->params, |
254 | function_node->body, function_node->ret_type, function_node->type_params, |
255 | /* erase attributes */ DictAttrs(Map<String, ObjectRef>())); |
256 | new_function = WithAttr(std::move(new_function), attr::kExtern, Integer(1)); |
257 | output_mod->Update(kv.first, new_function); |
258 | } |
259 | } |
260 | VLOG(1) << "MarkCompilerFunctionsAsExtern result:" << std::endl << PrettyPrint(output_mod); |
261 | return output_mod; |
262 | }; |
263 | |
264 | return tvm::transform::CreateModulePass(pass_func, 0, "MarkCompilerFunctionsAsExtern" , {}); |
265 | } |
266 | |
267 | tvm::transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> global_vars) { |
268 | runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func = |
269 | [global_vars = std::move(global_vars)](IRModule mod, transform::PassContext ctx) { |
270 | VLOG(1) << "InlineCompilerFunctionsBoundTo with global_vars: " << PrettyPrint(global_vars); |
271 | if (global_vars.empty()) { |
272 | return mod; |
273 | } |
274 | VLOG(1) << "InlineCompilerFunctions input:" << std::endl << PrettyPrint(mod); |
275 | IRModule output_mod = mod->ShallowCopy(); |
276 | for (const auto& kv : mod->functions) { |
277 | if (std::find(global_vars.begin(), global_vars.end(), kv.first) != global_vars.end()) { |
278 | output_mod->Remove(kv.first); |
279 | } else if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { |
280 | Expr new_body = OuterInliner(mod, global_vars).VisitExpr(function_node->body); |
281 | Function new_function = |
282 | WithFields(GetRef<Function>(function_node), /*opt_params=*/{}, new_body); |
283 | output_mod->Add(kv.first, new_function); |
284 | } |
285 | } |
286 | VLOG(1) << "InlineCompilerFunctionsBoundTo result:" << std::endl << PrettyPrint(output_mod); |
287 | return output_mod; |
288 | }; |
289 | |
290 | return tvm::transform::CreateModulePass(pass_func, 0, "InlineCompilerFunctionsBoundTo" , {}); |
291 | } |
292 | |
// Register the pass constructors defined above as global packed functions under the
// "relay._transform." prefix, so they can be looked up by name through TVM's global
// function registry (presumably from the Python frontend — confirm against the bindings).
TVM_REGISTER_GLOBAL("relay._transform.OutlineCompilerFunctionsWithExistingGlobalSymbols")
    .set_body_typed(OutlineCompilerFunctionsWithExistingGlobalSymbols);
TVM_REGISTER_GLOBAL("relay._transform.MarkCompilerFunctionsAsExtern")
    .set_body_typed(MarkCompilerFunctionsAsExtern);
TVM_REGISTER_GLOBAL("relay._transform.InlineCompilerFunctionsBoundTo")
    .set_body_typed(InlineCompilerFunctionsBoundTo);
299 | |
300 | } // namespace transform |
301 | } // namespace relay |
302 | } // namespace tvm |
303 | |