1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/transforms/compiler_function_utils.cc
22 * \brief Helper passes for working with functions with the "Compiler" attribute.
23 */
24
#include "./compiler_function_utils.h"

#include <algorithm>
#include <unordered_set>

#include "tvm/relay/analysis.h"
#include "tvm/relay/expr_functor.h"
#include "tvm/relay/transform.h"
30
31namespace tvm {
32namespace relay {
33namespace transform {
34namespace {
35
36/*!
37 * \brief Returns the \p FunctionNode of if \p expr if it is a "Compiler" function which should
38 * be processed by a pass using \p compiler_filter. Otherwise returns null.
39 */
40const FunctionNode* AsFunctionNode(const Expr& expr, const std::string& compiler_filter) {
41 if (const auto* function_node = expr.as<FunctionNode>()) {
42 Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler);
43 if (opt_compiler.defined() &&
44 (compiler_filter.empty() || opt_compiler.value() == compiler_filter)) {
45 return function_node;
46 }
47 }
48 return nullptr;
49}
50
/*!
 * \brief Rewrite calls to inlined and let-bound "Compiler" functions to global functions. The given
 * module will be extended with the newly outlined functions.
 *
 * A function is a candidate for outlining when it carries the "Compiler" attribute and, if the
 * compiler filter is non-empty, that attribute equals the filter (see AsFunctionNode).
 */
class Outliner : public MixedModeMutator {
 public:
  using MixedModeMutator::VisitExpr_;

  /*!
   * \param cache Supplies the global var for each outlined function. Stored as a raw pointer and
   * never freed here, so the caller must keep it alive for the lifetime of this Outliner.
   * \param compiler_filter If non-empty, only functions whose "Compiler" attribute equals this
   * value are outlined.
   * \param mod Module to which outlined functions are added (mutated through mod_->Add).
   */
  Outliner(GlobalSymbolCache* cache, std::string compiler_filter, IRModule mod)
      : cache_(cache), compiler_filter_(std::move(compiler_filter)), mod_(std::move(mod)) {}

  /*!
   * \brief Visits a let chain via ExpandANormalForm using separate pre- and post-visit callbacks.
   * A let var bound to a "Compiler" function of interest is replaced at its uses by the function
   * itself, and the binding is dropped. Rewrite_ then sees the function as a callee and can
   * outline it.
   */
  Expr VisitExpr_(const LetNode* op) final {
    auto pre_visit = [this](const LetNode* op) {
      Expr var = this->VisitExpr(op->var);
      Expr value = this->VisitExpr(op->value);

      if (AsFunctionNode(value, compiler_filter_)) {
        // Inline on-the-fly if the let-bound value is a function of interest.
        // Recording the function as the memoized result for the var means any later visit of the
        // var (e.g. as a call's op) yields the function itself.
        this->memo_[var] = value;
      }
    };
    auto post_visit = [this](const LetNode* op) {
      // Rely on the Memoizer to cache pre-visit values
      Expr value = this->VisitExpr(op->value);
      Expr body = this->VisitExpr(op->body);
      auto expr = GetRef<Expr>(op);

      if (AsFunctionNode(value, compiler_filter_)) {
        // The let binding is no longer needed since inlined on-the-fly above.
        this->memo_[expr] = this->VisitExpr(op->body);
      } else {
        // Keep the binding; rebuild the Let only if one of its parts was rewritten.
        Var var = Downcast<Var>(this->VisitExpr(op->var));
        if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
          this->memo_[expr] = expr;
        } else {
          this->memo_[expr] = Let(var, value, body);
        }
      }
    };
    ExpandANormalForm(op, pre_visit, post_visit);
    // post_visit stores the rewritten result for each let (including op) in memo_.
    return memo_[GetRef<Expr>(op)];
  }

  /*!
   * \brief If the (already rewritten) callee is a "Compiler" function of interest, replaces it by
   * the global var supplied by cache_. The function is added to mod_, with its "global_symbol"
   * attribute set to that var's name, unless mod_ already contains a global var of that name.
   */
  Expr Rewrite_(const CallNode* pre, const Expr& post) final {
    Call new_call = Downcast<Call>(post);
    if (const auto* function_node = AsFunctionNode(new_call->op, compiler_filter_)) {
      auto function = GetRef<Function>(function_node);
      DCHECK(FreeVars(function).empty()) << "Function marked with '" << attr::kCompiler
                                         << "' attribute should not have free variables";
      // Ask the cache to supply a unique global var for this function.
      GlobalVar global_symbol = cache_->GetGlobalSymbol(function);
      // Depending on the cache's implementation, two structurally equal (but not object
      // equal) functions may be assigned the same global symbol. If so we'll lift it just
      // once, but rewrite all the calls.
      if (!mod_->ContainGlobalVar(global_symbol->name_hint)) {
        function =
            WithAttr(std::move(function), tvm::attr::kGlobalSymbol, global_symbol->name_hint);
        mod_->Add(global_symbol, function);
      }
      // Update the call.
      return WithFields(new_call, global_symbol);
    }
    return post;
  }

 private:
  /*!
   * \brief A cached mapping from functions to global variables. Depending on the implementation
   * the cache may generate fresh symbols or require the function to already have a
   * "global_symbol" attribute, and may share symbols between structurally equal functions.
   */
  GlobalSymbolCache* cache_;
  /*! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */
  std::string compiler_filter_;
  /*! \brief Module being rewritten. */
  IRModule mod_;
};
128
129/*!
130 * \brief Inline immediate calls to "Composite" functions.
131 */
132class InnerInliner : public MixedModeMutator {
133 public:
134 InnerInliner() = default;
135
136 private:
137 using MixedModeMutator::Rewrite_;
138
139 Expr Rewrite_(const CallNode* pre, const Expr& post) final {
140 Call new_call = Downcast<Call>(post);
141 if (const auto* function_node = new_call->op.as<FunctionNode>()) {
142 ICHECK(function_node->GetAttr<String>(attr::kComposite).defined());
143 ICHECK_EQ(function_node->params.size(), new_call->args.size());
144 Map<Var, Expr> subst;
145 for (size_t i = 0; i < new_call->args.size(); ++i) {
146 subst.Set(function_node->params[i], new_call->args[i]);
147 }
148 return Bind(function_node->body, subst);
149 }
150 return post;
151 }
152};
153
154/*!
155 * \brief Inline calls to global "Compiler" functions with global var in \p global_vars.
156 * Both the 'outer' "Compiler" function and any 'inner' "Composite" functions in its body
157 * are inlined.
158 */
159class OuterInliner : public MixedModeMutator {
160 public:
161 OuterInliner(IRModule mod, Array<GlobalVar> global_vars_)
162 : mod_(std::move(mod)), global_vars_(std::move(global_vars_)) {}
163
164 private:
165 using MixedModeMutator::Rewrite_;
166
167 Expr Rewrite_(const CallNode* pre, const Expr& post) final {
168 Call new_call = Downcast<Call>(post);
169 if (const auto* global_var_node = new_call->op.as<GlobalVarNode>()) {
170 auto global_var = GetRef<GlobalVar>(global_var_node);
171 if (std::find(global_vars_.begin(), global_vars_.end(), global_var) != global_vars_.end()) {
172 BaseFunc base_func = mod_->Lookup(global_var);
173 const auto* function_node = base_func.as<FunctionNode>();
174 ICHECK(function_node);
175 ICHECK(function_node->GetAttr<String>(attr::kCompiler).defined());
176 ICHECK_EQ(function_node->params.size(), new_call->args.size());
177 Map<Var, Expr> subst;
178 for (size_t i = 0; i < new_call->args.size(); ++i) {
179 subst.Set(function_node->params[i], new_call->args[i]);
180 }
181 Expr new_body = InnerInliner().VisitExpr(function_node->body);
182 return Bind(new_body, subst);
183 }
184 }
185 return post;
186 }
187
188 private:
189 /*! \brief Original module we are processing. */
190 IRModule mod_;
191 /*! \brief Global vars of functions to inline. */
192 Array<GlobalVar> global_vars_;
193};
194
195} // namespace
196
// GlobalSymbolCache's destructor is defaulted here, out of line, rather than in the class
// definition (presumably declared in ./compiler_function_utils.h -- confirm).
GlobalSymbolCache::~GlobalSymbolCache() = default;
198
199GlobalVar ExistingGlobalSymbolCache::GetGlobalSymbol(const Function& function) {
200 Optional<String> opt_global_symbol = function->GetAttr<String>(tvm::attr::kGlobalSymbol);
201 ICHECK(opt_global_symbol.defined())
202 << "ExistingGlobalSymbolCache requires all functions to already have a '"
203 << tvm::attr::kGlobalSymbol << "' attribute";
204 std::string global_symbol = opt_global_symbol.value();
205 auto itr = global_vars_.find(global_symbol);
206 if (itr != global_vars_.end()) {
207 return itr->second;
208 }
209 // Ok if function does not have a checked_type, but if it does capture it in the global var.
210 GlobalVar global_var(global_symbol, function->checked_type_, function->span);
211 global_vars_.emplace(global_symbol, global_var);
212 return global_var;
213}
214
215tvm::transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cache,
216 std::string compiler_filter) {
217 runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
218 [cache = std::move(cache), compiler_filter = std::move(compiler_filter)](
219 IRModule mod, transform::PassContext ctx) {
220 VLOG(1) << "OutlineCompilerFunctions input:" << std::endl << PrettyPrint(mod);
221 IRModule output_mod = mod->ShallowCopy();
222 for (const auto& kv : mod->functions) {
223 if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
224 Expr new_body =
225 Outliner(cache.get(), compiler_filter, output_mod).VisitExpr(function_node->body);
226 Function new_function =
227 WithFields(GetRef<Function>(function_node), /*opt_params=*/{}, new_body);
228 output_mod->Add(kv.first, new_function);
229 }
230 }
231 VLOG(1) << "OutlineCompilerFunctions result:" << std::endl << PrettyPrint(output_mod);
232 return output_mod;
233 };
234
235 return tvm::transform::CreateModulePass(pass_func, 0, "OutlineCompilerFunctions", {});
236}
237
238// Any Java programmers in the house?
239tvm::transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(
240 std::string compiler_filter) {
241 return OutlineCompilerFunctions(std::make_shared<ExistingGlobalSymbolCache>(),
242 std::move(compiler_filter));
243}
244
245tvm::transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
246 runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
247 [compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) {
248 VLOG(1) << "MarkCompilerFunctionsAsExtern input:" << std::endl << PrettyPrint(mod);
249 IRModule output_mod = mod->ShallowCopy();
250 for (const auto& kv : mod->functions) {
251 if (const auto* function_node = AsFunctionNode(kv.second, compiler_filter)) {
252 auto new_function =
253 WithFields(GetRef<Function>(function_node), function_node->params,
254 function_node->body, function_node->ret_type, function_node->type_params,
255 /* erase attributes */ DictAttrs(Map<String, ObjectRef>()));
256 new_function = WithAttr(std::move(new_function), attr::kExtern, Integer(1));
257 output_mod->Update(kv.first, new_function);
258 }
259 }
260 VLOG(1) << "MarkCompilerFunctionsAsExtern result:" << std::endl << PrettyPrint(output_mod);
261 return output_mod;
262 };
263
264 return tvm::transform::CreateModulePass(pass_func, 0, "MarkCompilerFunctionsAsExtern", {});
265}
266
267tvm::transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> global_vars) {
268 runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
269 [global_vars = std::move(global_vars)](IRModule mod, transform::PassContext ctx) {
270 VLOG(1) << "InlineCompilerFunctionsBoundTo with global_vars: " << PrettyPrint(global_vars);
271 if (global_vars.empty()) {
272 return mod;
273 }
274 VLOG(1) << "InlineCompilerFunctions input:" << std::endl << PrettyPrint(mod);
275 IRModule output_mod = mod->ShallowCopy();
276 for (const auto& kv : mod->functions) {
277 if (std::find(global_vars.begin(), global_vars.end(), kv.first) != global_vars.end()) {
278 output_mod->Remove(kv.first);
279 } else if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
280 Expr new_body = OuterInliner(mod, global_vars).VisitExpr(function_node->body);
281 Function new_function =
282 WithFields(GetRef<Function>(function_node), /*opt_params=*/{}, new_body);
283 output_mod->Add(kv.first, new_function);
284 }
285 }
286 VLOG(1) << "InlineCompilerFunctionsBoundTo result:" << std::endl << PrettyPrint(output_mod);
287 return output_mod;
288 };
289
290 return tvm::transform::CreateModulePass(pass_func, 0, "InlineCompilerFunctionsBoundTo", {});
291}
292
// Register the passes as global packed functions so they can be looked up by name (presumably
// by the Python relay.transform bindings -- confirm).
TVM_REGISTER_GLOBAL("relay._transform.OutlineCompilerFunctionsWithExistingGlobalSymbols")
    .set_body_typed(OutlineCompilerFunctionsWithExistingGlobalSymbols);
TVM_REGISTER_GLOBAL("relay._transform.MarkCompilerFunctionsAsExtern")
    .set_body_typed(MarkCompilerFunctionsAsExtern);
TVM_REGISTER_GLOBAL("relay._transform.InlineCompilerFunctionsBoundTo")
    .set_body_typed(InlineCompilerFunctionsBoundTo);
299
300} // namespace transform
301} // namespace relay
302} // namespace tvm
303