/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/transforms/partition_graph.cc
 *
 * \brief Partition an input function into multiple functions based on the
 * inserted annotation nodes (i.e. compiler_begin and compiler_end). These
 * nodes are used as boundaries to partition the Relay function into multiple
 * regions that can be offloaded to different accelerators/backends.
 *
 * Each of these partitioned functions, a.k.a. regions, will be viewed as an
 * external function, and they will use the provided compiler for codegen.
 */
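
/* An illustrative sketch of the transformation (the IR below is hypothetical,
 * not taken from a real test case). Given a body annotated for an external
 * compiler "ccompiler":
 *
 *   %0 = annotation.compiler_begin(%x, compiler="ccompiler");
 *   %1 = annotation.compiler_begin(%y, compiler="ccompiler");
 *   %2 = add(%0, %1);
 *   %3 = annotation.compiler_end(%2, compiler="ccompiler");
 *
 * the pass lifts the region into a global function carrying the kGlobalSymbol,
 * kPrimitive, kCompiler, and kInline attributes set below, and replaces the
 * annotated region with a call to it:
 *
 *   def @ccompiler_main_0(%ccompiler_0_i0, %ccompiler_0_i1, Compiler="ccompiler", ...) {
 *     add(%ccompiler_0_i0, %ccompiler_0_i1)
 *   }
 *   ...
 *   %2 = @ccompiler_main_0(%x, %y);
 */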

#include <tvm/ir/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/name_transforms.h>

#include <algorithm>
#include <functional>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../analysis/annotated_region_set.h"
#include "../backend/name_transforms.h"
#include "../backend/utils.h"
#include "pass_utils.h"

namespace tvm {
namespace relay {

namespace partitioning {

/*! \brief This struct maintains the required metadata for a region to generate a corresponding
 * global function and function call. The global function will be passed to the target-specific
 * codegen, and the function call will be used in the transformed Relay graph to invoke the
 * function at runtime.
 */
struct RegionFuncMetadata {
  /*! \brief The call node of the generated global function for this region. */
  Call func_call;

  /*! \brief A list of argument pairs. Each pair includes (var, expr). The var is used
   * as a parameter of the generated function; the input expression is used as the
   * corresponding argument at the call site.
   */
  std::vector<std::pair<Var, Expr>> args;

  /*! \brief Map from each region output expr (compiler end) node to
   * the corresponding function output expr.
   */
  std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> region_func_out;

  /*! \brief Map from each region input expression (compiler begin) to
   * the corresponding function input variable. This cache is used to make sure
   * a region function will not have duplicated inputs even if it refers to
   * the same expr multiple times.
   */
  std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> region_func_in;
};
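
// Note: region_func_in acts as a cache keyed by the region input expression, so
// a region that consumes the same expr through several compiler_begin edges
// still produces a function with a single parameter for it, while `args` keeps
// the (parameter, call-site argument) pairs used to build the function and its
// call.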

/*! \brief This class partitions the expr labeled with begin and end annotations
 * into functions, one per annotated region. Each region is labeled with
 * a compiler attribute so that it will be handled by any compilers that are not
 * in the TVM stack.
 *
 * Input : A Relay module that has functions with disjoint annotated regions
 *         using compiler_begin and compiler_end. There could be multiple
 *         outputs.
 *
 * Output : A Relay module with global functions for such disjoint annotated
 *          regions with calls inserted at the respective locations.
 *
 * Dependencies : AnnotatedRegionSet utility class.
 *
 * Methodology :
 *      1) The AnnotatedRegionSet utility class is able to construct a collection
 *         of nodes that are bound by a given annotation -- here we use
 *         compiler_begin and compiler_end.
 *      2) Initially, RegionSets are populated for each function in the module.
 *      3) Then, the visitor pass traverses the function until a compiler_end node
 *         that belongs to a "region" is encountered.
 *      4) When the first compiler_end of a given annotated region is found,
 *         a function is formed and inserted.
 *         a) If the region has multiple outputs, a Tuple node (capturing
 *            all outputs) is returned.
 *      5) Thereafter, if we encounter another output of the same annotated
 *         region, the function has already been formed, so we look it up
 *         and add a TupleGetItemNode.
 *         a) We use the location index in the "rets" of each "Region" of the
 *            AnnotatedRegionSet as the TupleGetItemNode index.
 *      6) Therefore, functions will be created for all annotated regions.
 *         The name for each global function is created using the "Region" id and
 *         the compiler name.
 */
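
/* Illustrative sketch of the multi-output case described in steps 4) and 5)
 * above (hypothetical IR): a region with two distinct compiler_end outputs is
 * turned into a function returning a tuple, and each original output is served
 * by a TupleGetItem on the single call node:
 *
 *   %0 = @ccompiler_main_0(%x);  // returns a 2-field tuple
 *   %1 = %0.0;                   // first region output
 *   %2 = %0.1;                   // second region output
 */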

class Partitioner : public MixedModeMutator {
 public:
  Partitioner(const IRModule& module, bool bind_constants)
      : module_(module), bind_constants_(bind_constants) {
    std::set<std::string> func_names;
    for (auto f : module->functions) {
      GlobalVar f_var = f.first;
      BaseFunc f_func = f.second;
      std::string f_name = f_var.as<GlobalVarNode>()->name_hint;
      while (func_names.find(f_name) != func_names.end()) {
        f_name += "_a";
      }
      func_names.insert(f_name);

      // Create a region set per function in the module.
      auto region_set =
          AnnotatedRegionSet::Create(f_func, CompilerBeginOp(), CompilerEndOp(), f_name);
      regions_sets_[region_set] = f_func;
    }
  }

  Expr Rewrite_(const CallNode* call, const Expr& post) final {
    auto op_node = call->op.as<OpNode>();
    if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
      return post;
    } else if (call->op == CompilerBeginOp()) {
      // The annotation node is inserted on an edge, so it must have exactly one argument.
      ICHECK_EQ(call->args.size(), 1U);

      // Traverse the rest of the graph.
      Expr parent = call->args[0];
      auto input_expr = Downcast<Call>(post)->args[0];

      // Backtrack through the parents to find the first ancestor node that is
      // not a begin or end op.
      while (const auto* parent_call = parent.as<CallNode>()) {
        if (parent_call->op == CompilerBeginOp() || parent_call->op == CompilerEndOp()) {
          parent = parent_call->args[0];
        } else {
          break;
        }
      }

      AnnotatedRegion sg = GetRegion(GetRef<Call>(call));
      int index = GetArgIdx(sg, GetRef<Call>(call));
      ICHECK_NE(index, -1);

      if (region_func_meta_[sg].region_func_in.count(parent)) {
        return region_func_meta_[sg].region_func_in[parent];
      } else {
        // The type of the created variable is the same as the compiler_begin
        // node.
        std::string target = call->attrs.as<CompilerAttrs>()->compiler;
        std::string varname =
            target + "_" + std::to_string(sg->GetID()) + "_i" + std::to_string(index);
        auto var = Var(varname, GetRef<Call>(call)->checked_type_);

        std::pair<Var, Expr> cand = std::make_pair(var, input_expr);

        if (std::find(region_func_meta_[sg].args.begin(), region_func_meta_[sg].args.end(),
                      cand) == region_func_meta_[sg].args.end()) {
          region_func_meta_[sg].args.push_back(cand);
        }
        region_func_meta_[sg].region_func_in[parent] = var;
        return std::move(var);
      }
    } else {
      ICHECK_EQ(call->op, CompilerEndOp());
      // The annotation node is inserted on an edge, so it must have exactly one
      // argument.
      ICHECK_EQ(call->args.size(), 1U);

      AnnotatedRegion region = GetRegion(GetRef<Call>(call));

      // TODO(@manupa-arm): need to use the name/key of the parent function (to
      // which the region belongs) for the functions that are created.
      BaseFunc f = GetFunc(GetRef<Call>(call));

      // Traverse subgraph inputs.
      auto input = Downcast<Call>(post)->args[0];
      ICHECK(region.defined()) << "Region not defined for " << GetRef<Call>(call);
      // Functions are created for each annotated region when their first
      // output is encountered. If there are multiple outputs, a tuple node is
      // inserted at the end.

      if (!region_func_meta_[region].func_call.defined()) {
        // First time this region is encountered in the traversal; create the function.
        CreateFunction(region, call);
      }

      // Retrieve this particular output of the function.
      Expr region_out_expr = Downcast<Call>(GetRef<Call>(call))->args[0];
      ICHECK(region_func_meta_[region].region_func_out.count(region_out_expr));
      return region_func_meta_[region].region_func_out[region_out_expr];
    }
  }

  IRModule Partition() {
    auto glob_funcs = module_->functions;
    for (const auto& pair : glob_funcs) {
      if (auto* fn = pair.second.as<FunctionNode>()) {
        Function func = GetRef<Function>(fn);
        func = WithFields(func, func->params, VisitExpr(func->body));
        module_->Update(pair.first, func);
        module_ = transform::InferType()(module_);
      }
    }
    return module_;
  }
225
226 private:
227 /*!
228 * \brief Get the region an expression belongs to
229 * if its in a region.
230 */
231 AnnotatedRegion GetRegion(const Expr& e) {
232 for (auto sg_set_it : regions_sets_) {
233 auto sg_set = sg_set_it.first;
234 AnnotatedRegion sg = sg_set->GetRegion(e);
235 if (sg.defined()) {
236 return sg;
237 }
238 }
239 return AnnotatedRegion(nullptr);
240 }
241
242 /*!
243 * \brief Get the function an expression belongs to
244 * if its in a region.
245 */
246 BaseFunc GetFunc(const Expr& e) {
247 for (auto sg_set_it : regions_sets_) {
248 auto sg_set = sg_set_it.first;
249 auto func = sg_set_it.second;
250
251 AnnotatedRegion sg = sg_set->GetRegion(e);
252 if (sg.defined()) {
253 return func;
254 }
255 }
256 return BaseFunc(nullptr);
257 }
258
259 /*!
260 * \brief Get the index of the argument;
261 * this is to be used as tuplegetitem idx
262 */
263 int GetArgIdx(AnnotatedRegion sg, const Expr& arg) {
264 int idx = 0;
265 for (auto arg_ : sg->GetInputs()) {
266 if (arg == arg_) {
267 return idx;
268 }
269 idx++;
270 }
271 return -1;
272 }
273
274 /*!
275 * \brief Check if an expr is a constant or a tuple that only contain constants.
276 */
277 bool IsConstant(const Expr& expr) const {
278 if (expr->IsInstance<ConstantNode>()) return true;
279 if (!expr->IsInstance<TupleNode>()) return false;
280 const auto* tn = expr.as<TupleNode>();
281 return std::all_of(tn->fields.begin(), tn->fields.end(),
282 [](const Expr& e) { return e->IsInstance<ConstantNode>(); });
283 }

  /*!
   * \brief Create a call to the function that represents a region.
   * \note The customized optimization pipeline will be invoked as well to
   * optimize each function that is handled by external codegen.
   */
  Call CreateRegionCall(AnnotatedRegion region, const Array<Expr>& fields,
                        const CallNode* end_node) {
    Array<Var> params;
    Array<Expr> param_expr;
    Map<Var, Expr> params_bind;
    for (auto pair : region_func_meta_[region].args) {
      params.push_back(pair.first);
      if (bind_constants_ && IsConstant(pair.second)) {
        params_bind.Set(pair.first, pair.second);
      } else {
        param_expr.push_back(pair.second);
      }
    }

    Function global_region_func;
    if (fields.size() == 1) {
      // If there is only a single output, there is no need to add a tuple.
      global_region_func =
          Function(params, fields[0], end_node->args[0]->checked_type_, {}, DictAttrs());
    } else {
      auto tuple = Tuple(fields);
      global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
    }

    std::string target = end_node->attrs.as<CompilerAttrs>()->compiler;
    std::string name = target + "_" + region->GetName() + "_" + std::to_string(region->GetID());

    // Constant propagation.
    if (!params_bind.empty()) {
      global_region_func = Downcast<Function>(relay::Bind(global_region_func, params_bind));
    }
    std::string ext_opt = "relay.ext." + target + ".optimize";
    auto pf = tvm::runtime::Registry::Get(ext_opt);
    if (pf != nullptr) {
      auto mod = IRModule::FromExpr(global_region_func);
      mod = transform::InferType()(mod);
      mod = (*pf)(mod);
      global_region_func = Downcast<Function>(mod->Lookup("main"));
    }

    global_region_func =
        WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, runtime::String(name));
    global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
    global_region_func =
        WithAttr(std::move(global_region_func), attr::kCompiler, tvm::runtime::String(target));
    global_region_func = WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));

    GlobalVarSupply global_var_supply = GlobalVarSupply(module_);
    GlobalVar glob_func = global_var_supply->FreshGlobal(name, false);
    ICHECK(!module_->ContainGlobalVar(glob_func->name_hint))
        << "Global function " << glob_func->name_hint << " already exists";
    // Create a global function and add it to the IRModule for the region.
    // This way we lift the functions that should be handled by external
    // codegen to the module scope and rely on the pass manager to prevent
    // relay function-level passes (i.e. simplify inference and fusion)
    // from optimizing it.
    module_->Add(glob_func, global_region_func);
    module_ = relay::transform::InferType()(module_);

    // Create a call node for the function.
    auto call = Call(glob_func, param_expr);
    region_func_meta_[region].func_call = call;

    return call;
  }
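
  // Illustrative sketch of the optional per-target optimization hook looked up
  // above under "relay.ext.<target>.optimize"; the target name "ccompiler" and
  // the pass body are assumptions for illustration:
  //
  //   TVM_REGISTER_GLOBAL("relay.ext.ccompiler.optimize")
  //       .set_body_typed([](IRModule mod) {
  //         // Run target-specific rewrites on the partitioned function here.
  //         return mod;
  //       });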

  /*!
   * \brief Create a function and its function call for the given region. If the function has
   * multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes
   * will be created to serve output consumers.
   */
  void CreateFunction(AnnotatedRegion region, const CallNode* end_node) {
    // Create `fields`, a unique list of outputs.
    Array<Expr> fields;
    std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> out_expr_to_idx;
    int out_idx = 0;
    for (auto region_end_node : region->GetOutputs()) {
      auto ret_node = Downcast<Call>(region_end_node)->args[0];
      // Don't duplicate outputs.
      if (!out_expr_to_idx.count(ret_node)) {
        auto ret_expr = MixedModeMutator::VisitExpr(ret_node);
        fields.push_back(ret_expr);
        out_expr_to_idx[ret_node] = out_idx++;
      }
    }

    Call call = CreateRegionCall(region, fields, end_node);

    // Create output expr(s) for the function call.
    if (out_expr_to_idx.size() == 1) {
      // A single output directly uses the call node as the output expr.
      region_func_meta_[region].region_func_out[out_expr_to_idx.begin()->first] = call;
    } else {
      // Multiple outputs need TupleGetItem nodes as output exprs.
      for (auto pair : out_expr_to_idx) {
        Expr region_out_expr = pair.first;  // The arg of a compiler_end node of this region.
        int idx = pair.second;              // Corresponding function output tuple index.
        auto tuple_get_item = TupleGetItem(call, idx);
        tuple_get_item->checked_type_ = region_out_expr->checked_type_;
        region_func_meta_[region].region_func_out[region_out_expr] = tuple_get_item;
      }
    }
  }
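
  // Note: because `out_expr_to_idx` de-duplicates outputs above, two
  // compiler_end nodes wrapping the same expr share one tuple field and map to
  // the same TupleGetItem expr.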

  /*! \brief Map from each region to the metadata of its generated function. */
  std::unordered_map<AnnotatedRegion, RegionFuncMetadata, ObjectPtrHash, ObjectPtrEqual>
      region_func_meta_;

  /*! \brief Each region set is associated with a function in the module.
   * This map maintains the mapping between region sets and the functions
   * they belong to.
   */
  std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectPtrHash, ObjectPtrEqual> regions_sets_;

  /*! \brief The IRModule used for partitioning. */
  IRModule module_;

  /*! \brief Whether or not to bind constants in partitioned subgraphs. */
  bool bind_constants_{false};
};

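/*! \brief Remove compiler_begin/compiler_end annotations whose compiler
 * attribute is "default", keeping the annotated expression itself. For example
 * (illustrative): compiler_begin(%x, compiler="default") is rewritten to %x.
 */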
IRModule RemoveDefaultAnnotations(IRModule module) {
  class DefaultRemover : public ExprRewriter {
   public:
    DefaultRemover() = default;

    Expr Rewrite_(const CallNode* call, const Expr& post) final {
      auto attrs = call->attrs.as<CompilerAttrs>();
      if (attrs != nullptr && attrs->compiler == "default") {
        return Downcast<Call>(post)->args[0];
      }
      return post;
    }
  };

  auto glob_funcs = module->functions;
  // The module is mutable; hence, we make a copy of it.
  module.CopyOnWrite();
  for (const auto& pair : glob_funcs) {
    if (auto* fn = pair.second.as<FunctionNode>()) {
      auto func = GetRef<Function>(fn);
      DefaultRemover remover;
      auto removed = PostOrderRewrite(func->body, &remover);
      func = WithFields(func, func->params, removed);
      module->Update(pair.first, func);
      module = relay::transform::InferType()(module);
    }
  }
  return module;
}

/*! \brief There can be regions with multiple outputs where each output
 * could be a tuple output. Such tuple outputs need to be flattened;
 * otherwise the function would create tuples of tuples. Moreover, tuples
 * of tuples are valid Relay; however, they are not currently supported by
 * the graph executor or the Relay VM.
 */
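
// Illustrative sketch of the rewrite performed by FlattenTupleOutputs
// (hypothetical IR):
//
//   compiler_end((%a, %b), compiler="ccompiler")
//     ==>  (compiler_end(%a, compiler="ccompiler"),
//           compiler_end(%b, compiler="ccompiler"))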

// New annotations need to be added for each flattened output.
static const PackedFunc* make_end_op =
    runtime::Registry::Get("relay.op.annotation._make.compiler_end");

IRModule FlattenTupleOutputs(IRModule module) {
  class TupleOutFlattener : public ExprRewriter {
   public:
    TupleOutFlattener() = default;

    Expr Rewrite_(const CallNode* call, const Expr& post) final {
      if (call->op == CompilerEndOp()) {
        std::string target = call->attrs.as<CompilerAttrs>()->compiler;
        // Annotation ops must have exactly one argument.
        ICHECK_EQ(call->args.size(), 1U);
        auto annotated_op = Downcast<Call>(post)->args[0];
        if (const auto* tuple_node = annotated_op.as<TupleNode>()) {
          Array<Expr> new_fields;
          new_fields.reserve(tuple_node->fields.size());

          // Here, each field of the tuple is annotated with a compiler_end.
          for (auto& tn_arg : tuple_node->fields) {
            new_fields.push_back((*make_end_op)(tn_arg, target));
          }

          // Return a tuple of compiler_ends in the place of the tuple that was
          // annotated with a compiler_end.
          return WithFields(GetRef<Tuple>(tuple_node), new_fields);
        }
      }
      return post;
    }
  };

  auto glob_funcs = module->functions;
  // The module is mutable; hence, we make a copy of it.
  module.CopyOnWrite();
  for (const auto& pair : glob_funcs) {
    if (auto* fn = pair.second.as<FunctionNode>()) {
      Function func = GetRef<Function>(fn);
      TupleOutFlattener to_flattener;
      auto removed = PostOrderRewrite(func->body, &to_flattener);
      func = WithFields(func, func->params, removed);
      module->Update(pair.first, func);
      module = relay::transform::InferType()(module);
    }
  }
  return module;
}

class NameMangleExtFuncs : public MixedModeMutator {
 public:
  explicit NameMangleExtFuncs(const IRModule& module, std::function<String(String)> mangle_fn)
      : module_(module), mangle_fn_(mangle_fn) {}

  IRModule Run() {
    auto glob_funcs = module_->functions;

    // Collect the function names to be mangled and create the
    // global mangled variables.
    for (const auto& pair : glob_funcs) {
      if (auto* fn = pair.second.as<FunctionNode>()) {
        auto func = GetRef<Function>(fn);
        if (func->GetAttr<String>(attr::kCompiler).defined()) {
          auto fn_name_mangled = tvm::runtime::SanitizeName(mangle_fn_(pair.first->name_hint));
          GlobalVar gvar = GlobalVar(fn_name_mangled);
          mangled_gvars_[pair.first->name_hint] = gvar;
        }
      }
    }

    // Walk the tree and mangle the functions, then replace compiler functions
    // with mangled functions in the module.
    IRModule new_module = module_->ShallowCopy();
    new_module->functions = {};

    for (const auto& pair : glob_funcs) {
      if (auto* fn = pair.second.as<FunctionNode>()) {
        auto func = GetRef<Function>(fn);

        if (func->GetAttr<String>(attr::kCompiler).defined()) {
          auto new_dict = func->attrs->dict;
          new_dict.Set(tvm::attr::kGlobalSymbol,
                       String(tvm::runtime::SanitizeName(mangle_fn_(pair.first->name_hint))));
          func = WithFields(func, func->params, VisitExpr(func->body), func->ret_type,
                            func->type_params, DictAttrs(new_dict));

          new_module->Add(mangled_gvars_[pair.first->name_hint], func);
        } else {
          func = WithFields(func, func->params, VisitExpr(func->body));
          new_module->Add(pair.first, func);
        }
      }
    }

    return new_module;
  }

 private:
  Expr Rewrite_(const CallNode* call, const Expr& post) final {
    Expr new_expr = post;
    const CallNode* new_call = new_expr.as<CallNode>();
    auto op_node = new_call->op.as<GlobalVarNode>();
    if (op_node == nullptr || mangled_gvars_.find(op_node->name_hint) == mangled_gvars_.end()) {
      return new_expr;
    } else {
      return Call(mangled_gvars_[op_node->name_hint], new_call->args, new_call->attrs,
                  new_call->type_args, new_call->span);
    }
  }

  /*! \brief The IRModule used for partitioning. */
  IRModule module_;
  /*! \brief The function used to mangle function names. */
  std::function<String(String)> mangle_fn_;
  /*! \brief Table used to store (unmangled_var_name, mangled_gvar) pairs. */
  std::unordered_map<std::string, GlobalVar> mangled_gvars_;
};
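
// For example (assuming the default mangle function constructed in
// PartitionGraph below, which joins the module name and the function name via
// runtime::get_name_mangled), an external function "ccompiler_main_0" in a
// module named "default" would be rebound to a GlobalVar named
// "default_ccompiler_main_0" after tvm::runtime::SanitizeName.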

}  // namespace partitioning

namespace transform {

Pass PartitionGraph(String mod_name, bool bind_constants) {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> flatten_tuples = [=](IRModule m,
                                                                                 PassContext pc) {
    // There could be compiler_end annotations on tuples. If the corresponding
    // region has multiple compiler_ends, this would lead to the creation of
    // tuples of tuples. Thus, we flatten the tuples by transferring the
    // compiler_end to the tuple inputs.
    return partitioning::FlattenTupleOutputs(m);
  };

  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> remove_defaults = [=](IRModule m,
                                                                                  PassContext pc) {
    // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute
    // by treating them as un-annotated, but we don't have it yet. This workaround pass removes
    // all "default" annotations and should be deleted in the future.
    return partitioning::RemoveDefaultAnnotations(m);
  };

  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func = [=](IRModule m,
                                                                            PassContext pc) {
    return partitioning::Partitioner(m, bind_constants).Partition();
  };

  auto name_mangling_fn = [mod_name](String name) {
    return runtime::get_name_mangled(mod_name, name);
  };

  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> name_mangling_func =
      [=](IRModule m, PassContext pc) {
        return partitioning::NameMangleExtFuncs(m, name_mangling_fn).Run();
      };

  auto flatten_tuples_pass = CreateModulePass(flatten_tuples, 0, "FlattenNestedTuples", {});
  auto remove_default_pass = CreateModulePass(remove_defaults, 0, "RemoveDefaultAnnotations", {});
  auto partition_pass = CreateModulePass(part_func, 0, "PartitionGraph", {});
  auto name_mangling_pass = CreateModulePass(name_mangling_func, 0, "NameMangleExtFuncs", {});
  return Sequential(
      {flatten_tuples_pass, remove_default_pass, partition_pass, name_mangling_pass, InferType()});
}
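
// A minimal usage sketch (assuming the input module already carries
// compiler_begin/compiler_end annotations, e.g. produced by the AnnotateTarget
// pass):
//
//   IRModule mod = ...;
//   mod = relay::transform::PartitionGraph("default", /*bind_constants=*/true)(mod);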

TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph")
    .set_body_typed([](String mod_name, bool bind_constants) {
      return transform::PartitionGraph(mod_name, bind_constants);
    });

}  // namespace transform

}  // namespace relay
}  // namespace tvm