1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file src/relay/transforms/partition_graph.cc
 *
 * \brief Partition an input function into multiple functions according to
 * the inserted annotation nodes (i.e. compiler_begin and compiler_end).
 * These nodes are used as boundaries to partition the Relay function into
 * multiple regions that can be offloaded to different accelerators/backends.
 *
 * Each of these partitioned functions, a.k.a. regions, will be viewed as
 * external functions, and they will use the provided compiler for codegen.
 */
31 | |
#include <tvm/ir/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/name_transforms.h>

#include <algorithm>
#include <functional>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../analysis/annotated_region_set.h"
#include "../backend/name_transforms.h"
#include "../backend/utils.h"
#include "pass_utils.h"
50 | |
51 | namespace tvm { |
52 | namespace relay { |
53 | |
54 | namespace partitioning { |
55 | |
/*! \brief This struct maintains the required metadata for a region to generate a corresponding
 * global function and function call. The global function will be passed to the target-specific
 * codegen, and the function call will be used in the transformed Relay graph to invoke the
 * function at runtime.
 */
struct RegionFuncMetadata {
  /*! \brief The call node of the generated global function for this region.
   * Undefined until the first compiler_end of the region is encountered.
   */
  Call func_call;

  /*! \brief A list of argument pairs. Each pair includes (var, expr). var is used
   * as a function node argument; input expression is used as a function call parameter.
   */
  std::vector<std::pair<Var, Expr>> args;

  /*! \brief Map from each region output expr (compiler end) node to
   * the corresponding function output expr (the function call itself, or a
   * TupleGetItem of it when the region has multiple outputs).
   */
  std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> region_func_out;

  /*! \brief Map from each region input expression (compiler begin) to
   * the corresponding function input variable. This cache is used to make sure
   * a region function will not have duplicated inputs even if it refers to
   * the same expr multiple times.
   */
  std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> region_func_in;
};
81 | |
/*! \brief This class partitions the expr labeled with begin and end annotations
 * into function containing multiple regions. Each region is labeled with
 * a compiler attribute so that it will be handled by any compilers that are not
 * in the TVM stack.
 *
 * Input : A Relay module that has functions with disjoint annotated regions
 *         using compiler_begin and compiler_end. There could be multiple
 *         outputs.
 *
 * Output : A Relay module with global functions for such disjoint annotated
 *          regions with calls inserted at the respective location
 *
 * Dependencies : AnnotatedRegionSet Utility class.
 *
 * Methodology :
 *      1) The AnnotatedRegionSet utility class is able to construct a collection
 *         of nodes that are bound by a given annotation -- here we use
 *         compiler_begin and compiler_end
 *      2) Initially, for each function in the module, RegionSets are populated.
 *      3) Then, the Visitor pass is traversed until a compiler_end node is
 *         encountered that belongs to a "region".
 *      4) When the first compiler_end of a given annotated region is found,
 *         a function is formed and inserted.
 *         a) if the region has multiple outputs, a Tuple node (capturing
 *            all outputs) is returned.
 *      5) Thereafter, if we encounter another output of the same annotated
 *         region, it is important to note that the function is already formed.
 *         Therefore, it will look up the function and add a TupleGetItemNode.
 *         a) We will use the location index of "rets" of each "Region" of the
 *            AnnotatedRegionSet as the TupleGetItemNode index.
 *      6) Therefore, functions will be created for all annotated regions.
 *         The name for each global function is created using the "Region" id
 *         and the compiler name.
 */
116 | |
class Partitioner : public MixedModeMutator {
 public:
  /*!
   * \brief Construct a partitioner over \p module.
   * \param module The module whose functions will be partitioned.
   * \param bind_constants Whether constant call arguments are bound (embedded)
   *        into the created region functions instead of being passed in as
   *        call parameters.
   */
  Partitioner(const IRModule& module, bool bind_constants)
      : module_(module), bind_constants_(bind_constants) {
    // Region-set names must be unique across the module; duplicates are
    // suffixed with "_a" until a fresh name is found.
    std::set<std::string> func_names;
    for (auto f : module->functions) {
      GlobalVar f_var = f.first;
      BaseFunc f_func = f.second;
      std::string f_name = f_var.as<GlobalVarNode>()->name_hint;
      while (func_names.find(f_name) != func_names.end()) {
        f_name += "_a";
      }
      func_names.insert(f_name);

      // Creating regionset per function in the module.
      auto region_set =
          AnnotatedRegionSet::Create(f_func, CompilerBeginOp(), CompilerEndOp(), f_name);
      regions_sets_[region_set] = f_func;
    }
  }

  /*!
   * \brief Rewrite compiler_begin/compiler_end annotation calls.
   *
   * A compiler_begin is replaced by the function-parameter variable created
   * for the region input it annotates. A compiler_end is replaced by the
   * corresponding output of the (lazily created) region-function call.
   * Every other call is returned unchanged.
   */
  Expr Rewrite_(const CallNode* call, const Expr& post) final {
    auto op_node = call->op.as<OpNode>();
    if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
      return post;
    } else if (call->op == CompilerBeginOp()) {
      // The annotation node is inserted on edge so it must have only one argument.
      ICHECK_EQ(call->args.size(), 1U);

      // Traverse the rest graph.
      Expr parent = call->args[0];
      auto input_expr = Downcast<Call>(post)->args[0];

      // Backtrace the parent to find the first ancestor node that is not a begin or end op
      while (const auto* parent_call = parent.as<CallNode>()) {
        if (parent_call->op == CompilerBeginOp() || parent_call->op == CompilerEndOp()) {
          parent = parent_call->args[0];
        } else {
          break;
        }
      }

      AnnotatedRegion sg = GetRegion(GetRef<Call>(call));
      int index = GetArgIdx(sg, GetRef<Call>(call));
      ICHECK_NE(index, -1);

      if (region_func_meta_[sg].region_func_in.count(parent)) {
        // An input variable was already created for this expression; reuse it
        // so the region function does not get duplicated parameters.
        return region_func_meta_[sg].region_func_in[parent];
      } else {
        // The type of the created variable is the same as the compiler_begin
        // node.
        std::string target = call->attrs.as<CompilerAttrs>()->compiler;
        std::string varname =
            target + "_" + std::to_string(sg->GetID()) + "_i" + std::to_string(index);
        auto var = Var(varname, GetRef<Call>(call)->checked_type_);

        std::pair<Var, Expr> cand = std::make_pair(var, input_expr);

        if (std::find(region_func_meta_[sg].args.begin(), region_func_meta_[sg].args.end(), cand) ==
            region_func_meta_[sg].args.end()) {
          region_func_meta_[sg].args.push_back(cand);
        }
        region_func_meta_[sg].region_func_in[parent] = var;
        return std::move(var);
      }
    } else {
      ICHECK_EQ(call->op, CompilerEndOp());
      // The annotation node is inserted on edge so it must have only one
      // argument.
      ICHECK_EQ(call->args.size(), 1U);

      AnnotatedRegion region = GetRegion(GetRef<Call>(call));

      // TODO(@manupa-arm) : need to use the parent function (to which region
      // belongs to) name/key for the functions that are created
      BaseFunc f = GetFunc(GetRef<Call>(call));

      // Traverse subgraph inputs.
      auto input = Downcast<Call>(post)->args[0];
      ICHECK(region.defined()) << "Region not defined for " << GetRef<Call>(call);
      // Functions are created for each annotated region when its first
      // output is encountered. If the region has multiple outputs, a tuple
      // node is inserted at the end.

      if (!region_func_meta_[region].func_call.defined()) {
        // First time this region is encountered in the traversal. Creating the function.
        CreateFunction(region, call);
      }

      // Retrieve this particular output of function.
      Expr region_out_expr = Downcast<Call>(GetRef<Call>(call))->args[0];
      ICHECK(region_func_meta_[region].region_func_out.count(region_out_expr));
      return region_func_meta_[region].region_func_out[region_out_expr];
    }
  }

  /*!
   * \brief Partition every function in the module.
   * \return The updated module, with one global function added per annotated
   *         region and the original bodies rewritten to call them.
   */
  IRModule Partition() {
    auto glob_funcs = module_->functions;
    for (const auto& pair : glob_funcs) {
      if (auto* fn = pair.second.as<FunctionNode>()) {
        Function func = GetRef<Function>(fn);
        func = WithFields(func, func->params, VisitExpr(func->body));
        module_->Update(pair.first, func);
        // Re-infer types after each function so later rewrites can read
        // checked_type_ on newly created nodes.
        module_ = transform::InferType()(module_);
      }
    }
    return module_;
  }

 private:
  /*!
   * \brief Get the region an expression belongs to,
   * if it is in a region.
   */
  AnnotatedRegion GetRegion(const Expr& e) {
    for (auto sg_set_it : regions_sets_) {
      auto sg_set = sg_set_it.first;
      AnnotatedRegion sg = sg_set->GetRegion(e);
      if (sg.defined()) {
        return sg;
      }
    }
    return AnnotatedRegion(nullptr);
  }

  /*!
   * \brief Get the function an expression belongs to,
   * if it is in a region.
   */
  BaseFunc GetFunc(const Expr& e) {
    for (auto sg_set_it : regions_sets_) {
      auto sg_set = sg_set_it.first;
      auto func = sg_set_it.second;

      AnnotatedRegion sg = sg_set->GetRegion(e);
      if (sg.defined()) {
        return func;
      }
    }
    return BaseFunc(nullptr);
  }

  /*!
   * \brief Get the index of the argument;
   * this is to be used as tuplegetitem idx.
   * Returns -1 when \p arg is not an input of \p sg.
   */
  int GetArgIdx(AnnotatedRegion sg, const Expr& arg) {
    int idx = 0;
    for (auto arg_ : sg->GetInputs()) {
      if (arg == arg_) {
        return idx;
      }
      idx++;
    }
    return -1;
  }

  /*!
   * \brief Check if an expr is a constant or a tuple that only contain constants.
   */
  bool IsConstant(const Expr& expr) const {
    if (expr->IsInstance<ConstantNode>()) return true;
    if (!expr->IsInstance<TupleNode>()) return false;
    const auto* tn = expr.as<TupleNode>();
    return std::all_of(tn->fields.begin(), tn->fields.end(),
                       [](const Expr& e) { return e->IsInstance<ConstantNode>(); });
  }

  /*!
   * \brief Create a call to the function that represents a region.
   * \param region The region being lifted into a global function.
   * \param fields The (deduplicated) output expressions of the region.
   * \param end_node The compiler_end call that triggered function creation;
   *        supplies the compiler target and the single-output return type.
   * \note The customized optimization pipeline will be invoked as well to
   * optimize each function that is handled by external codegen.
   */
  Call CreateRegionCall(AnnotatedRegion region, const Array<Expr>& fields,
                        const CallNode* end_node) {
    Array<Var> params;
    Array<Expr> param_expr;
    Map<Var, Expr> params_bind;
    // Split region inputs into real call parameters and constants to be
    // bound directly into the function body (when bind_constants_ is set).
    for (auto pair : region_func_meta_[region].args) {
      params.push_back(pair.first);
      if (bind_constants_ && IsConstant(pair.second)) {
        params_bind.Set(pair.first, pair.second);
      } else {
        param_expr.push_back(pair.second);
      }
    }

    Function global_region_func;
    if (fields.size() == 1) {
      // If there are only a single output; no need to add a tuple
      global_region_func =
          Function(params, fields[0], end_node->args[0]->checked_type_, {}, DictAttrs());
    } else {
      auto tuple = Tuple(fields);
      global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
    }

    std::string target = end_node->attrs.as<CompilerAttrs>()->compiler;
    std::string name = target + "_" + region->GetName() + "_" + std::to_string(region->GetID());

    // Constant propagation
    if (!params_bind.empty()) {
      global_region_func = Downcast<Function>(relay::Bind(global_region_func, params_bind));
    }
    // Run the target's custom optimization pipeline, if one is registered
    // under "relay.ext.<target>.optimize".
    std::string ext_opt = "relay.ext." + target + ".optimize";
    auto pf = tvm::runtime::Registry::Get(ext_opt);
    if (pf != nullptr) {
      auto mod = IRModule::FromExpr(global_region_func);
      mod = transform::InferType()(mod);
      mod = (*pf)(mod);
      global_region_func = Downcast<Function>(mod->Lookup("main"));
    }

    global_region_func =
        WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, runtime::String(name));
    global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
    global_region_func =
        WithAttr(std::move(global_region_func), attr::kCompiler, tvm::runtime::String(target));
    global_region_func = WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));

    GlobalVarSupply global_var_supply = GlobalVarSupply(module_);
    GlobalVar glob_func = global_var_supply->FreshGlobal(name, false);
    ICHECK(!module_->ContainGlobalVar(glob_func->name_hint))
        << "Global function " << glob_func->name_hint << " already exists";
    // Create a global function and add it to the IRModule for the region.
    // This way we lift the functions that should be handled by external
    // codegen to the module scope and rely on the pass manager to prevent
    // relay function level passes (i.e. simplify inference and fusion)
    // optimizing it.
    module_->Add(glob_func, global_region_func);
    module_ = relay::transform::InferType()(module_);

    // Create a call node for the function.
    auto call = Call(glob_func, param_expr);
    region_func_meta_[region].func_call = call;

    return call;
  }

  /*!
   * \brief Create a function and its function call for the given region. If the function has
   * multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes
   * will be created to serve output consumers.
   */
  void CreateFunction(AnnotatedRegion region, const CallNode* end_node) {
    // Create fields which is a unique list of outputs.
    Array<Expr> fields;
    std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> out_expr_to_idx;
    int out_idx = 0;
    for (auto region_end_node : region->GetOutputs()) {
      auto ret_node = Downcast<Call>(region_end_node)->args[0];
      // Don't duplicate outputs.
      if (!out_expr_to_idx.count(ret_node)) {
        auto ret_expr = MixedModeMutator::VisitExpr(ret_node);
        fields.push_back(ret_expr);
        out_expr_to_idx[ret_node] = out_idx++;
      }
    }

    Call call = CreateRegionCall(region, fields, end_node);

    // Create output expr(s) for the function call.
    if (out_expr_to_idx.size() == 1) {
      // Single output directly uses the call node as the output expr.
      region_func_meta_[region].region_func_out[out_expr_to_idx.begin()->first] = call;
    } else {
      // Multiple outputs need to create TupleGetItem nodes as output exprs.
      for (auto pair : out_expr_to_idx) {
        Expr region_out_expr = pair.first;  // The arg of a compiler end node of this region.
        int idx = pair.second;              // Corresponding function output tuple index.
        auto tuple_get_item = TupleGetItem(call, idx);
        tuple_get_item->checked_type_ = region_out_expr->checked_type_;
        region_func_meta_[region].region_func_out[region_out_expr] = tuple_get_item;
      }
    }
  }

  /*! \brief Map from each region to its metadata of the generated function. */
  std::unordered_map<AnnotatedRegion, RegionFuncMetadata, ObjectPtrHash, ObjectPtrEqual>
      region_func_meta_;

  /*! \brief Each region set is associated with a function in the module.
   * This map maintains the mapping between regionsets and the function it
   * belongs to.
   */
  std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectPtrHash, ObjectPtrEqual> regions_sets_;

  /*!\brief The IRModule used for partitioning. */
  IRModule module_;

  /*!\brief Whether or not to bind constants in partitioned subgraphs. */
  bool bind_constants_{false};
};
410 | |
411 | IRModule RemoveDefaultAnnotations(IRModule module) { |
412 | class DefaultRemover : public ExprRewriter { |
413 | public: |
414 | DefaultRemover() = default; |
415 | |
416 | Expr Rewrite_(const CallNode* call, const Expr& post) final { |
417 | auto attrs = call->attrs.as<CompilerAttrs>(); |
418 | if (attrs != nullptr && attrs->compiler == "default" ) { |
419 | return Downcast<Call>(post)->args[0]; |
420 | } |
421 | return post; |
422 | } |
423 | }; |
424 | |
425 | auto glob_funcs = module->functions; |
426 | // module is mutable, hence, we make a copy of it. |
427 | module.CopyOnWrite(); |
428 | for (const auto& pair : glob_funcs) { |
429 | if (auto* fn = pair.second.as<FunctionNode>()) { |
430 | auto func = GetRef<Function>(fn); |
431 | DefaultRemover remover; |
432 | auto removed = PostOrderRewrite(func->body, &remover); |
433 | func = WithFields(func, func->params, removed); |
434 | module->Update(pair.first, func); |
435 | module = relay::transform::InferType()(module); |
436 | } |
437 | } |
438 | return module; |
439 | } |
440 | |
/*! \brief There can be regions with multiple outputs where each output
 * could itself be a tuple. Such tuple outputs need to be flattened,
 * otherwise the created function would return tuples of tuples. Tuples
 * of tuples are valid Relay, however they are not currently supported by
 * the graph executor or the Relay VM.
 */

// Creator of compiler_end annotations; new annotations are added for each
// flattened output. NOTE(review): resolved at static-initialization time,
// which assumes the op registry is populated before this TU's globals
// initialize — confirm against TVM's registry initialization order.
static const PackedFunc* make_end_op =
    runtime::Registry::Get("relay.op.annotation._make.compiler_end");
451 | |
452 | IRModule FlattenTupleOutputs(IRModule module) { |
453 | class TupleOutFlattener : public ExprRewriter { |
454 | public: |
455 | TupleOutFlattener() = default; |
456 | |
457 | Expr Rewrite_(const CallNode* call, const Expr& post) final { |
458 | if (call->op == CompilerEndOp()) { |
459 | std::string target = call->attrs.as<CompilerAttrs>()->compiler; |
460 | // Arguments of annotation ops should be 1 |
461 | ICHECK_EQ(call->args.size(), 1U); |
462 | auto annotated_op = Downcast<Call>(post)->args[0]; |
463 | if (const auto* tuple_node = annotated_op.as<TupleNode>()) { |
464 | Array<Expr> new_fields; |
465 | new_fields.reserve(tuple_node->fields.size()); |
466 | |
467 | // Here each input of the tuple will be annotated with compiler_ends |
468 | for (auto& tn_arg : tuple_node->fields) { |
469 | new_fields.push_back((*make_end_op)(tn_arg, target)); |
470 | } |
471 | |
472 | // Return a tuple of compiler_ends in the place of the tuple that was |
473 | // annotated with a compiler_end. |
474 | return WithFields(GetRef<Tuple>(tuple_node), new_fields); |
475 | } |
476 | } |
477 | return post; |
478 | } |
479 | }; |
480 | |
481 | auto glob_funcs = module->functions; |
482 | // module is mutable, hence, we make a copy of it. |
483 | module.CopyOnWrite(); |
484 | for (const auto& pair : glob_funcs) { |
485 | if (auto* fn = pair.second.as<FunctionNode>()) { |
486 | Function func = GetRef<Function>(fn); |
487 | TupleOutFlattener to_flattener; |
488 | auto removed = PostOrderRewrite(func->body, &to_flattener); |
489 | func = WithFields(func, func->params, removed); |
490 | module->Update(pair.first, func); |
491 | module = relay::transform::InferType()(module); |
492 | } |
493 | } |
494 | return module; |
495 | } |
496 | |
497 | class NameMangleExtFuncs : public MixedModeMutator { |
498 | public: |
499 | explicit NameMangleExtFuncs(const IRModule& module, std::function<String(String)> mangle_fn) |
500 | : module_(module), mangle_fn_(mangle_fn) {} |
501 | |
502 | IRModule Run() { |
503 | auto glob_funcs = module_->functions; |
504 | |
505 | // Collect function names to be mangled and create |
506 | // global mangled variables |
507 | for (const auto& pair : glob_funcs) { |
508 | if (auto* fn = pair.second.as<FunctionNode>()) { |
509 | auto func = GetRef<Function>(fn); |
510 | if (func->GetAttr<String>(attr::kCompiler).defined()) { |
511 | auto fn_name_mangled = tvm::runtime::SanitizeName(mangle_fn_(pair.first->name_hint)); |
512 | GlobalVar gvar = GlobalVar(fn_name_mangled); |
513 | mangled_gvars_[pair.first->name_hint] = gvar; |
514 | } |
515 | } |
516 | } |
517 | |
518 | // Walk the tree and mangle the functions. Then replace compiler functions |
519 | // with mangled functions in the module |
520 | IRModule new_module = module_->ShallowCopy(); |
521 | new_module->functions = {}; |
522 | |
523 | for (const auto& pair : glob_funcs) { |
524 | if (auto* fn = pair.second.as<FunctionNode>()) { |
525 | auto func = GetRef<Function>(fn); |
526 | |
527 | if (func->GetAttr<String>(attr::kCompiler).defined()) { |
528 | auto new_dict = func->attrs->dict; |
529 | new_dict.Set(tvm::attr::kGlobalSymbol, |
530 | String(tvm::runtime::SanitizeName(mangle_fn_(pair.first->name_hint)))); |
531 | func = WithFields(func, func->params, VisitExpr(func->body), func->ret_type, |
532 | func->type_params, DictAttrs(new_dict)); |
533 | |
534 | new_module->Add(mangled_gvars_[pair.first->name_hint], func); |
535 | } else { |
536 | func = WithFields(func, func->params, VisitExpr(func->body)); |
537 | new_module->Add(pair.first, func); |
538 | } |
539 | } |
540 | } |
541 | |
542 | return new_module; |
543 | } |
544 | |
545 | private: |
546 | Expr Rewrite_(const CallNode* call, const Expr& post) final { |
547 | Expr new_expr = post; |
548 | const CallNode* new_call = new_expr.as<CallNode>(); |
549 | auto op_node = new_call->op.as<GlobalVarNode>(); |
550 | if (op_node == nullptr || mangled_gvars_.find(op_node->name_hint) == mangled_gvars_.end()) { |
551 | return new_expr; |
552 | } else { |
553 | return Call(mangled_gvars_[op_node->name_hint], new_call->args, new_call->attrs, |
554 | new_call->type_args, new_call->span); |
555 | } |
556 | } |
557 | |
558 | /*!\brief The IRModule used for partitioning. */ |
559 | IRModule module_; |
560 | /*!\brief The function used to mangle operators name */ |
561 | std::function<String(String)> mangle_fn_; |
562 | /*!\brief Tabled used to store (unmangled_var_name, mangled_gvar) pairs*/ |
563 | std::unordered_map<std::string, GlobalVar> mangled_gvars_; |
564 | }; |
565 | |
566 | } // namespace partitioning |
567 | |
568 | namespace transform { |
569 | |
570 | Pass PartitionGraph(String mod_name, bool bind_constants) { |
571 | runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> flatten_tuples = [=](IRModule m, |
572 | PassContext pc) { |
573 | // There could be compiler_end annotations on tuples |
574 | // If the corresponding region is having multiple compiler_ends, |
575 | // this would lead to creation of tuples of tuples. |
576 | // Thus, we flatten the tuples by transfering the compiler_end to |
577 | // the tuple inputs. |
578 | return partitioning::FlattenTupleOutputs(m); |
579 | }; |
580 | |
581 | runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> remove_defaults = [=](IRModule m, |
582 | PassContext pc) { |
583 | // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute |
584 | // by treating them as un-annotated, but we don't have it yet. This workaround pass removes |
585 | // all "default" annotations and should be deleted in the future. |
586 | return partitioning::RemoveDefaultAnnotations(m); |
587 | }; |
588 | |
589 | runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func = [=](IRModule m, |
590 | PassContext pc) { |
591 | return partitioning::Partitioner(m, bind_constants).Partition(); |
592 | }; |
593 | |
594 | auto name_mangling_fn = [mod_name](String name) { |
595 | return runtime::get_name_mangled(mod_name, name); |
596 | }; |
597 | |
598 | runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> name_mangling_func = |
599 | [=](IRModule m, PassContext pc) { |
600 | return partitioning::NameMangleExtFuncs(m, name_mangling_fn).Run(); |
601 | }; |
602 | |
603 | auto flatten_tuples_pass = CreateModulePass(flatten_tuples, 0, "FlattenNestedTuples" , {}); |
604 | auto remove_default_pass = CreateModulePass(remove_defaults, 0, "RemoveDefaultAnnotations" , {}); |
605 | auto partition_pass = CreateModulePass(part_func, 0, "PartitionGraph" , {}); |
606 | auto name_mangling_pass = CreateModulePass(name_mangling_func, 0, "NameMangleExtFuncs" , {}); |
607 | return Sequential( |
608 | {flatten_tuples_pass, remove_default_pass, partition_pass, name_mangling_pass, InferType()}); |
609 | } |
610 | |
// Expose the pass constructor to the frontend (e.g. Python) under the
// global name "relay._transform.PartitionGraph".
TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph")
    .set_body_typed([](String mod_name, bool bind_constants) {
      return transform::PartitionGraph(mod_name, bind_constants);
    });
615 | |
616 | } // namespace transform |
617 | |
618 | } // namespace relay |
619 | } // namespace tvm |
620 | |