1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file src/relay/transforms/fuse_ops.cc |
23 | * |
24 | * \brief This is a backend-aware optimization pass. |
25 | * Fuse necessary ops into a single one. |
26 | */ |
27 | #include <tvm/relay/analysis.h> |
28 | #include <tvm/relay/executor.h> |
29 | #include <tvm/relay/expr_functor.h> |
30 | #include <tvm/relay/op_attr_types.h> |
31 | #include <tvm/relay/transform.h> |
32 | #include <tvm/tir/op.h> |
33 | |
34 | #include "../../support/arena.h" |
35 | #include "../analysis/graph_partitioner.h" |
36 | #include "../op/annotation/annotation.h" |
37 | #include "./pass_utils.h" |
38 | #include "./pattern_utils.h" |
39 | |
40 | namespace tvm { |
41 | namespace relay { |
42 | |
43 | /* |
44 | Note on Fusing algorithm: |
45 | |
46 | The main challenge of general fusor is to handle possible diamond shape branches, |
47 | in the following graph, conv2d can be fused to elemwise add. |
48 | |
49 | conv2d |
50 | / | \ |
51 | / | \ |
52 | op op op |
53 | \ | / |
54 | \ | / |
55 | elemwise add |
56 | | |
57 | |
58 | However, at the point of conv2d we do not necessarily know that all the future paths |
59 | will merge at the elemwise add. The fusion algorithm applies post-dominator analysis. |
60 | |
The immediate post-dominator of a node is the closest node through which all future paths
from that node must pass. In the above case, the elemwise add is the post-dominator of conv2d.
The general algorithm is as follows:
64 | |
65 | - Construct a DAG of dataflow graph for dominator analysis |
66 | - Construct a post-dominator tree which gives immediate post dominator of each node. |
67 | - Run fusion algorithm with the given post-dominator information. |
68 | |
69 | Note that, because we run analysis on a DAG, we use a single pass post-dominator |
70 | tree construction algorithm via LCA, which is simpler than the full version that handles cycles. |
71 | |
72 | The fusion algorithm traverses from each node and checks if it can be fused to its |
73 | immediate post dominator. It has to check the following things: |
74 | |
75 | - CheckPath: check all the path between a node and its immediate post-dominator |
76 | satisfies the fuse condition. |
  - Note that these intermediate nodes may already be fused with other nodes; the algorithm
    will still run correctly.
79 | - CommitFuse: mark all the nodes between source and post-dominator as the same group. |
  - We use a Union-Find data structure to manage the groups.
81 | */ |
82 | using support::LinkedList; |
83 | using support::LinkNode; |
84 | |
85 | constexpr uint32_t kMaxFusedOps = 256; |
86 | |
87 | static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion" ); |
88 | |
89 | TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.max_depth" , Integer); |
90 | TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.link_params" , Bool); |
91 | |
92 | // Creator of post dominator tree of the dataflow |
93 | class IndexedForwardGraphCreator : private ExprVisitor { |
94 | public: |
95 | static IndexedForwardGraph Create(support::Arena* arena, const Expr& body) { |
96 | IndexedForwardGraphCreator creator(arena); |
97 | return creator.Prepare(body); |
98 | } |
99 | |
100 | private: |
101 | explicit IndexedForwardGraphCreator(support::Arena* arena) : arena_(arena) {} |
102 | |
103 | IndexedForwardGraph Prepare(const Expr& body) { |
104 | this->Update(body, nullptr, kOpaque); |
105 | this->VisitExpr(body); |
106 | return std::move(graph_); |
107 | } |
108 | |
109 | private: |
110 | /*! \brief allocator of all the internal node object */ |
111 | support::Arena* arena_; |
112 | // The output. |
113 | IndexedForwardGraph graph_; |
114 | // attribute equal comparator |
115 | StructuralEqual attr_equal_; |
116 | // Update the message stored at the node. |
117 | void Update(const Expr& node, IndexedForwardGraph::Node* parent, OpPatternKind pattern) { |
118 | const tvm::Object* key = node.get(); |
119 | IndexedForwardGraph::Node* current; |
120 | auto it = graph_.node_map.find(key); |
121 | if (it != graph_.node_map.end()) { |
122 | current = it->second; |
123 | } else { |
124 | current = arena_->make<IndexedForwardGraph::Node>(); |
125 | graph_.node_map[key] = current; |
126 | } |
127 | if (parent != nullptr) { |
128 | auto* link = arena_->make<LinkNode<IndexedForwardGraph::Edge>>(); |
129 | link->value.node = parent; |
130 | link->value.pattern = pattern; |
131 | current->outputs.Push(link); |
132 | } else { |
133 | current->extern_ref = true; |
134 | } |
135 | } |
136 | |
137 | void AddNode(const tvm::Object* key) { |
138 | auto it = graph_.node_map.find(key); |
139 | ICHECK(it != graph_.node_map.end()) << "Cannot find node " << GetRef<ObjectRef>(key); |
140 | IndexedForwardGraph::Node* node = it->second; |
141 | ICHECK(node->ref == nullptr); |
142 | node->ref = key; |
143 | node->index = graph_.post_dfs_order.size(); |
144 | graph_.post_dfs_order.push_back(node); |
145 | } |
146 | |
147 | // Post order tree |
148 | void VisitExpr_(const FunctionNode* op) final { |
149 | // Skip the function that should be handled by external codegen. |
150 | if (op->GetAttr<String>(attr::kCompiler).defined()) return; |
151 | |
152 | for (auto param : op->params) { |
153 | this->Update(param, nullptr, kOpaque); |
154 | } |
155 | this->Update(op->body, nullptr, kOpaque); |
156 | ExprVisitor::VisitExpr_(op); |
157 | } |
158 | |
159 | void VisitExpr_(const ConstantNode* op) final { |
160 | this->AddNode(op); |
161 | IndexedForwardGraph::Node* node = graph_.node_map.at(op); |
162 | DataType dtype = DataType(op->data->dtype); |
163 | // This rule must be consistent with code generator. |
164 | bool is_simple_const = |
165 | (dtype == DataType::Int(32) || dtype == DataType::Int(64) || dtype == DataType::Float(32) || |
166 | dtype == DataType::Float(64) || dtype == DataType::Bool()); |
167 | if (op->is_scalar() && is_simple_const) { |
168 | node->pattern = kElemWise; |
169 | } else { |
170 | // for now, mark non-scalar constant |
171 | // as opaque, we will not choose to fuse it. |
172 | node->pattern = kOpaque; |
173 | } |
174 | } |
175 | |
176 | void VisitExpr_(const CallNode* call) final { |
177 | ICHECK(graph_.node_map.count(call)); |
178 | IndexedForwardGraph::Node* node = graph_.node_map.at(call); |
179 | static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern" ); |
180 | // Now we set the pattern of this call. |
181 | // |
182 | // If we see a call mentioning an operator we should mark it with its |
183 | // annotated pattern. |
184 | // |
185 | // If the pattern is not annotated we will default to opaque. |
186 | // |
187 | // Finally if the operator position is not a call node we will |
188 | // need to call Update, as it may be an arbitrary expression. |
189 | OpPatternKind op_pattern = kOpaque; |
190 | if (const OpNode* opnode = call->op.as<OpNode>()) { |
191 | auto op = GetRef<Op>(opnode); |
192 | if (IsDynamic(call->checked_type()) && IsDataDependent(call)) { |
193 | // output of a shape func can't be fed to a data-dependent shape func |
194 | op_pattern = kOpaque; |
195 | } else { |
196 | op_pattern = static_cast<OpPatternKind>(fpattern[op]); |
197 | } |
198 | } else { |
199 | this->Update(call->op, node, kOpaque); |
200 | } |
201 | |
202 | node->pattern = op_pattern; |
203 | this->Update(call->op, nullptr, kOpaque); |
204 | const auto* rtype = call->checked_type().as<TensorTypeNode>(); |
205 | // pass the analysis back to all the children it references. |
206 | for (size_t i = 0; i < call->args.size(); ++i) { |
207 | const auto* arg_type = call->args[i]->checked_type().as<TensorTypeNode>(); |
208 | // specifically check if result type is the same as arguments type |
209 | OpPatternKind edge_pattern = op_pattern; |
210 | if (edge_pattern == kBroadcast && arg_type != nullptr && rtype != nullptr && |
211 | attr_equal_(rtype->shape, arg_type->shape)) { |
212 | edge_pattern = kElemWise; |
213 | } |
214 | this->Update(call->args[i], node, edge_pattern); |
215 | } |
216 | ExprVisitor::VisitExpr_(call); |
217 | this->AddNode(call); |
218 | } |
219 | |
220 | void VisitExpr_(const TupleNode* op) final { |
221 | ICHECK(graph_.node_map.count(op)); |
222 | IndexedForwardGraph::Node* tuple_node = graph_.node_map.at(op); |
223 | tuple_node->pattern = kTuple; |
224 | for (const Expr& field : op->fields) { |
225 | if (field->checked_type().as<TensorTypeNode>()) { |
226 | this->Update(field, tuple_node, kInjective); |
227 | } else { |
228 | this->Update(field, nullptr, kOpaque); |
229 | } |
230 | } |
231 | ExprVisitor::VisitExpr_(op); |
232 | this->AddNode(op); |
233 | } |
234 | |
235 | void VisitExpr_(const TupleGetItemNode* op) final { |
236 | auto tuple_type = op->tuple->checked_type().as<TupleTypeNode>(); |
237 | ICHECK(tuple_type); |
238 | // When TVM lowers a fused function, it expects all arguments to be a Tensor or |
239 | // a tuple containing only Tensors. But this tuple may contain a reference or |
240 | // another tuple. To avoid modifying codegen logic, we do not allow fusing through this node |
241 | // if the tuple contains such non Tensor fields. However, all fields will be recursively |
242 | // visited via call to ExprVisitor::VisitExpr_(op) below and corresponding visitor methods. |
243 | bool has_non_tensor = false; |
244 | for (auto ty : tuple_type->fields) { |
245 | if (!ty.as<TensorTypeNode>()) { |
246 | has_non_tensor = true; |
247 | break; |
248 | } |
249 | } |
250 | if (has_non_tensor) { |
251 | this->Update(op->tuple, nullptr, kOpaque); |
252 | } else { |
253 | ICHECK(graph_.node_map.count(op)); |
254 | IndexedForwardGraph::Node* node = graph_.node_map.at(op); |
255 | node->pattern = kInjective; |
256 | this->Update(op->tuple, node, kInjective); |
257 | } |
258 | ExprVisitor::VisitExpr_(op); |
259 | this->AddNode(op); |
260 | } |
261 | |
262 | void VisitExpr_(const VarNode* op) final { this->AddNode(op); } |
263 | |
264 | void VisitExpr_(const LetNode* op) final { |
265 | // do not fuse through let. |
266 | auto pre_visit = [this](const LetNode* op) { |
267 | // Rely on the Memoizer to cache pre-visit values |
268 | this->Update(op->var, nullptr, kOpaque); |
269 | this->Update(op->value, nullptr, kOpaque); |
270 | this->Update(op->body, nullptr, kOpaque); |
271 | this->VisitExpr(op->var); |
272 | this->VisitExpr(op->value); |
273 | }; |
274 | auto post_visit = [this](const LetNode* op) { |
275 | this->VisitExpr(op->body); |
276 | this->visit_counter_[op] += 1; |
277 | this->AddNode(op); |
278 | }; |
279 | ExpandANormalForm(op, pre_visit, post_visit); |
280 | } |
281 | |
282 | void VisitExpr_(const IfNode* op) final { |
283 | // do not fuse through if. |
284 | this->Update(op->cond, nullptr, kOpaque); |
285 | this->Update(op->true_branch, nullptr, kOpaque); |
286 | this->Update(op->false_branch, nullptr, kOpaque); |
287 | ExprVisitor::VisitExpr_(op); |
288 | this->AddNode(op); |
289 | } |
290 | |
291 | void VisitExpr_(const RefCreateNode* op) final { |
292 | this->Update(op->value, nullptr, kOpaque); |
293 | ExprVisitor::VisitExpr_(op); |
294 | this->AddNode(op); |
295 | } |
296 | |
297 | void VisitExpr_(const RefReadNode* op) final { |
298 | this->Update(op->ref, nullptr, kOpaque); |
299 | ExprVisitor::VisitExpr_(op); |
300 | this->AddNode(op); |
301 | } |
302 | |
303 | void VisitExpr_(const RefWriteNode* op) final { |
304 | this->Update(op->ref, nullptr, kOpaque); |
305 | this->Update(op->value, nullptr, kOpaque); |
306 | ExprVisitor::VisitExpr_(op); |
307 | this->AddNode(op); |
308 | } |
309 | |
310 | void VisitExpr_(const MatchNode* op) final { |
311 | this->Update(op->data, nullptr, kOpaque); |
312 | for (const Clause& c : op->clauses) { |
313 | this->Update(c->rhs, nullptr, kOpaque); |
314 | } |
315 | ExprVisitor::VisitExpr_(op); |
316 | this->AddNode(op); |
317 | } |
318 | }; |
319 | |
320 | class FuseMutator : private MixedModeMutator { |
321 | public: |
322 | FuseMutator(int fuse_opt_level, size_t max_fuse_depth, bool link_params) |
323 | : fuse_opt_level_(fuse_opt_level), |
324 | max_fuse_depth_(max_fuse_depth), |
325 | link_params_(link_params) {} |
326 | |
327 | // Run the transform |
328 | Expr Transform(const Expr& body) { |
329 | return Transform(body, fuse_opt_level_, max_fuse_depth_, link_params_); |
330 | } |
331 | |
332 | protected: |
333 | // Run the transform |
334 | Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth, bool link_params) { |
335 | // setup the group map. |
336 | auto graph = IndexedForwardGraphCreator::Create(&arena_, body); |
337 | auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth).Partition(graph); |
338 | for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { |
339 | ICHECK(graph.post_dfs_order[nid]->ref != nullptr); |
340 | gmap_[graph.post_dfs_order[nid]->ref] = groups[nid]; |
341 | } |
342 | // The following line can be used for debug. |
343 | // this->DebugDumpGroup(body); |
344 | return this->Mutate(body); |
345 | } |
346 | |
347 | private: |
348 | int fuse_opt_level_; |
349 | size_t max_fuse_depth_; |
350 | bool link_params_; |
351 | |
352 | using MixedModeMutator::VisitExpr_; |
353 | |
354 | /*! \brief Temporary information from each group. */ |
355 | struct GroupInfo { |
356 | public: |
357 | // The parameters of the function. |
358 | Array<Var> params; |
359 | // The arguments to call the functions. |
360 | Array<Expr> arguments; |
361 | // Get a new parameter or allocate an old one |
362 | Var GetOrAllocParam(const Expr& expr, const Type& type) { |
363 | // run linear scan as most fused groups contain only a few inputs. |
364 | for (size_t i = 0; i < arguments.size(); ++i) { |
365 | if (expr.same_as(arguments[i])) return params[i]; |
366 | } |
367 | // create a new parameter. |
368 | std::ostringstream os; |
369 | os << "p" << params.size(); |
370 | auto var = Var(os.str(), type); |
371 | params.push_back(var); |
372 | arguments.push_back(expr); |
373 | return var; |
374 | } |
375 | }; |
376 | /*! \brief Internal arena. */ |
377 | support::Arena arena_; |
378 | /*! \brief The group assignment map. */ |
379 | std::unordered_map<const Object*, GraphPartitioner::Group*> gmap_; |
380 | /* \brief Internal group information map. */ |
381 | std::unordered_map<GraphPartitioner::Group*, GroupInfo> ginfo_; |
382 | |
383 | // Skip primitive function. |
384 | Expr VisitExpr_(const FunctionNode* fn_node) { |
385 | if (fn_node->HasNonzeroAttr(attr::kPrimitive)) { |
386 | return GetRef<Expr>(fn_node); |
387 | } else { |
388 | return ExprMutator::VisitExpr_(fn_node); |
389 | } |
390 | } |
391 | |
392 | // Transform calls. |
393 | Expr Rewrite_(const CallNode* call, const Expr& post) { |
394 | if (call->op.as<OpNode>()) { |
395 | static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational" ); |
396 | static auto fqnncanonicalize = Op::GetAttrMap<FTVMLegalize>("FTVMQnnCanonicalize" ); |
397 | |
398 | Op op = Downcast<Op>(call->op); |
399 | if (fnoncomputational.get(op, false) && !fqnncanonicalize.count(op)) { |
400 | return ExprMutator::VisitExpr_(call); |
401 | } |
402 | |
403 | // If it is a primitive op call |
404 | // then we must have a group assignment for it already. |
405 | ICHECK(gmap_.count(call)); |
406 | if (call->op == stop_fusion_op) { |
407 | return ExprMutator::VisitExpr(call->args[0]); |
408 | } |
409 | auto* ret_group = gmap_.at(call)->FindRoot(); |
410 | Array<Expr> new_args = GetNewArguments(call->args, ret_group); |
411 | |
412 | auto new_call = Call(call->op, new_args, call->attrs, call->type_args, call->span); |
413 | |
414 | if (ret_group->root_ref == call) { |
415 | // This is the root of the group |
416 | // create the new call node. |
417 | return MakeNewFunction(ret_group, call->checked_type(), new_call); |
418 | } else { |
419 | // This is an intermediate node of a fused function |
420 | // simply return the new call. |
421 | return std::move(new_call); |
422 | } |
423 | } else { |
424 | return ExprMutator::VisitExpr_(call); |
425 | } |
426 | } |
427 | |
428 | Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) { |
429 | auto* ret_group = gmap_.at(tuple_node)->FindRoot(); |
430 | if (ret_group->root_ref == tuple_node) { |
431 | return ExprMutator::VisitExpr_(tuple_node); |
432 | } |
433 | // This tuple is an intermediate node in the group |
434 | Array<Expr> new_fields = GetNewArguments(tuple_node->fields, ret_group); |
435 | return WithFields(GetRef<Tuple>(tuple_node), new_fields); |
436 | } |
437 | |
438 | Expr Rewrite_(const TupleGetItemNode* tuple_get, const Expr& post) { |
439 | auto* ret_group = gmap_.at(tuple_get)->FindRoot(); |
440 | auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0]; |
441 | auto new_node = TupleGetItem(new_tuple, tuple_get->index); |
442 | if (ret_group->root_ref == tuple_get) { |
443 | if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) { |
444 | // Isolated. This case occurs when tuple is created by an Opaque op |
445 | // e.g. multibox_transform_loc |
446 | return ExprMutator::VisitExpr_(tuple_get); |
447 | } |
448 | // A new function whose output is a tuple field access |
449 | return MakeNewFunction(ret_group, tuple_get->checked_type(), new_node); |
450 | } |
451 | // This is an intermediate node in the group |
452 | return std::move(new_node); |
453 | } |
454 | |
455 | Expr VisitExpr_(const LetNode* op) final { |
456 | auto pre_visit = [this](const LetNode* op) { |
457 | // Rely on the Memoizer to cache pre-visit values |
458 | this->VisitExpr(op->var); |
459 | this->VisitExpr(op->value); |
460 | }; |
461 | auto post_visit = [this](const LetNode* op) { |
462 | // Rely on the Memoizer to cache pre-visit values |
463 | Var var = Downcast<Var>(this->VisitExpr(op->var)); |
464 | Expr value = this->VisitExpr(op->value); |
465 | // Visit body and cache the op |
466 | Expr body = this->VisitExpr(op->body); |
467 | auto expr = GetRef<Expr>(op); |
468 | if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { |
469 | this->memo_[expr] = expr; |
470 | } else { |
471 | this->memo_[expr] = Let(var, value, body); |
472 | } |
473 | }; |
474 | ExpandANormalForm(op, pre_visit, post_visit); |
475 | return memo_[GetRef<Expr>(op)]; |
476 | } |
477 | |
478 | Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) { |
479 | // Quickly check special properties of the fused function. |
480 | // A pass to check if the fused op contains only reshape ops. |
481 | class CheckReshapeOnly : public ExprVisitor { |
482 | public: |
483 | void VisitExpr_(const CallNode* cn) final { |
484 | this->has_call = true; |
485 | static auto freshape_op = Op::GetAttrMap<TReshapeOp>("TReshapeOp" ); |
486 | |
487 | if (!freshape_op.get(cn->op, false)) { |
488 | this->reshape_only = false; |
489 | } |
490 | |
491 | if (!this->reshape_only) return; |
492 | ExprVisitor::VisitExpr_(cn); |
493 | } |
494 | |
495 | void VisitExpr_(const VarNode* vn) final { |
496 | if (!vn->type_annotation.defined() || !vn->type_annotation->IsInstance<TensorTypeNode>()) { |
497 | this->reshape_only = false; |
498 | } |
499 | } |
500 | |
501 | bool reshape_only = true; |
502 | bool has_call = false; |
503 | } visitor; |
504 | |
505 | visitor(body); |
506 | const GroupInfo& ginfo = ginfo_[group]; |
507 | auto func = Function(ginfo.params, body, ret_type, {}); |
508 | func = WithAttr(std::move(func), attr::kPrimitive, tvm::Integer(visitor.has_call)); |
509 | // TODO(mbs): "reshape" cleanup. |
510 | if (visitor.has_call && visitor.reshape_only) { |
511 | func = WithAttr(std::move(func), attr::kReshapeOnly, tvm::Integer(visitor.reshape_only)); |
512 | } |
513 | return Call(func, ginfo.arguments, Attrs()); |
514 | } |
515 | |
516 | Array<Expr> GetNewArguments(const tvm::Array<Expr>& args, |
517 | GraphPartitioner::Group* current_group) { |
518 | Array<Expr> new_args; |
519 | for (auto arg : args) { |
520 | auto* arg_group = gmap_.at(arg.get())->FindRoot(); |
521 | auto type = arg->checked_type(); |
522 | Expr new_arg = this->Mutate(arg); |
523 | if (current_group != arg_group) { |
524 | if (!link_params_ || new_arg.as<ConstantNode>() == nullptr) { |
525 | Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type); |
526 | new_args.push_back(param); |
527 | } else { |
528 | new_args.push_back(new_arg); |
529 | } |
530 | } else { |
531 | new_args.push_back(new_arg); |
532 | } |
533 | } |
534 | return new_args; |
535 | } |
536 | |
537 | // Debug function, dump the group assignment in text. |
538 | void DebugDumpGroup(const Expr& body) { |
539 | std::string text = AsText(body, false, [this](const ObjectRef& expr) -> std::string { |
540 | auto it = gmap_.find(expr.get()); |
541 | if (it == gmap_.end()) return "" ; |
542 | std::ostringstream os; |
543 | auto* group = it->second->FindRoot(); |
544 | os << " /* group=" << group << " */" ; |
545 | return os.str(); |
546 | }); |
547 | LOG(INFO) << "Dump of group info:\n" << text; |
548 | } |
549 | }; |
550 | |
551 | Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, bool link_params, |
552 | const IRModule& module) { |
553 | return FuseMutator(fuse_opt_level, max_fuse_depth, link_params).Transform(expr); |
554 | } |
555 | |
556 | namespace transform { |
557 | |
558 | Pass FuseOps(int fuse_opt_level) { |
559 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
560 | [=](Function f, IRModule m, PassContext pc) { |
561 | bool link_params = false; |
562 | Executor executor = |
563 | m->GetAttr<Executor>(tvm::attr::kExecutor).value_or(NullValue<Executor>()); |
564 | link_params = executor.defined() |
565 | ? executor->attrs.GetAttr<Bool>("link-params" ).value_or(Bool(link_params)) |
566 | : link_params; |
567 | link_params = pc->GetConfig("relay.FuseOps.link_params" , Bool(link_params)).value(); |
568 | int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; |
569 | auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth" , Integer(kMaxFusedOps)); |
570 | return Downcast<Function>( |
571 | FuseOps(f, opt_level, max_fuse_depth.value().IntValue(), link_params, m)); |
572 | }; |
573 | return CreateFunctionPass(pass_func, 0, "FuseOps" , {"InferType" }); |
574 | } |
575 | |
576 | TVM_REGISTER_GLOBAL("relay._transform.FuseOps" ).set_body_typed(FuseOps); |
577 | |
578 | } // namespace transform |
579 | |
580 | } // namespace relay |
581 | } // namespace tvm |
582 | |