/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file src/relay/transforms/fuse_ops.cc
 *
 * \brief This is a backend-aware optimization pass.
 *   Fuse necessary ops into a single one.
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/executor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <tvm/tir/op.h>

#include "../../support/arena.h"
#include "../analysis/graph_partitioner.h"
#include "../op/annotation/annotation.h"
#include "./pass_utils.h"
#include "./pattern_utils.h"

namespace tvm {
namespace relay {

/*
  Note on the fusion algorithm:

  The main challenge for a general fuser is handling possible diamond-shaped branches.
  In the following graph, conv2d can be fused into the elemwise add.

            conv2d
            /  |  \
           /   |   \
         op    op   op
          \    |    /
           \   |   /
          elemwise add
               |

  However, at the point of conv2d we do not necessarily know that all the future paths
  will merge at the elemwise add. The fusion algorithm applies post-dominator analysis.

  The immediate post-dominator of a node is the closest node through which all future paths
  pass. In the above case, the elemwise add is the post-dominator of conv2d. The general
  algorithm is as follows:

  - Construct a DAG of the dataflow graph for dominator analysis.
  - Construct a post-dominator tree, which gives the immediate post-dominator of each node.
  - Run the fusion algorithm with the given post-dominator information.

  Note that, because we run the analysis on a DAG, we use a single-pass post-dominator
  tree construction algorithm via LCA, which is simpler than the full version that handles
  cycles.

  The fusion algorithm traverses from each node and checks whether it can be fused to its
  immediate post-dominator. It has to check the following things:

  - CheckPath: check that every path between a node and its immediate post-dominator
    satisfies the fuse condition.
    - Note that these intermediate nodes may already be fused with other nodes; the
      algorithm will still run correctly.
  - CommitFuse: mark all the nodes between the source and the post-dominator as the same group.
    - We use a Union-Find data structure to manage the groups.
*/
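
/*
  For illustration, a minimal before/after sketch of what the pass produces (Relay-style
  pseudo text; the names %x, %y, %f, %p0, %p1 are hypothetical): a fused group is outlined
  into a function marked as Primitive, and the original expressions are replaced by a call
  to that function.

    before:
      %1 = add(%x, %y);
      %2 = nn.relu(%1);

    after:
      %f = fn (%p0, %p1, Primitive=1) { nn.relu(add(%p0, %p1)) };
      %2 = %f(%x, %y);
*/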
using support::LinkedList;
using support::LinkNode;

constexpr uint32_t kMaxFusedOps = 256;

static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion");

TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.max_depth", Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.link_params", Bool);
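// Usage sketch (not part of this file): these options can be overridden through the pass
// context, e.g. from the Python API, where `mod` and `target` are placeholder names:
//   with tvm.transform.PassContext(opt_level=3, config={"relay.FuseOps.max_depth": 30}):
//       lib = relay.build(mod, target)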

// Creator of the indexed forward dataflow graph that is later used for dominator analysis.
class IndexedForwardGraphCreator : private ExprVisitor {
 public:
  static IndexedForwardGraph Create(support::Arena* arena, const Expr& body) {
    IndexedForwardGraphCreator creator(arena);
    return creator.Prepare(body);
  }

 private:
  explicit IndexedForwardGraphCreator(support::Arena* arena) : arena_(arena) {}

  IndexedForwardGraph Prepare(const Expr& body) {
    this->Update(body, nullptr, kOpaque);
    this->VisitExpr(body);
    return std::move(graph_);
  }

 private:
  /*! \brief Allocator of all the internal node objects. */
  support::Arena* arena_;
  // The output graph.
  IndexedForwardGraph graph_;
  // Attribute equality comparator.
  StructuralEqual attr_equal_;
  // Update the entry stored for the node: record an edge to `parent` with the given pattern,
  // or mark the node as externally referenced when `parent` is nullptr.
  void Update(const Expr& node, IndexedForwardGraph::Node* parent, OpPatternKind pattern) {
    const tvm::Object* key = node.get();
    IndexedForwardGraph::Node* current;
    auto it = graph_.node_map.find(key);
    if (it != graph_.node_map.end()) {
      current = it->second;
    } else {
      current = arena_->make<IndexedForwardGraph::Node>();
      graph_.node_map[key] = current;
    }
    if (parent != nullptr) {
      auto* link = arena_->make<LinkNode<IndexedForwardGraph::Edge>>();
      link->value.node = parent;
      link->value.pattern = pattern;
      current->outputs.Push(link);
    } else {
      current->extern_ref = true;
    }
  }

  void AddNode(const tvm::Object* key) {
    auto it = graph_.node_map.find(key);
    ICHECK(it != graph_.node_map.end()) << "Cannot find node " << GetRef<ObjectRef>(key);
    IndexedForwardGraph::Node* node = it->second;
    ICHECK(node->ref == nullptr);
    node->ref = key;
    node->index = graph_.post_dfs_order.size();
    graph_.post_dfs_order.push_back(node);
  }

  // The visitors below record nodes in post-DFS order.
  void VisitExpr_(const FunctionNode* op) final {
    // Skip the function that should be handled by external codegen.
    if (op->GetAttr<String>(attr::kCompiler).defined()) return;

    for (auto param : op->params) {
      this->Update(param, nullptr, kOpaque);
    }
    this->Update(op->body, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
  }

  void VisitExpr_(const ConstantNode* op) final {
    this->AddNode(op);
    IndexedForwardGraph::Node* node = graph_.node_map.at(op);
    DataType dtype = DataType(op->data->dtype);
    // This rule must be consistent with the code generator.
    bool is_simple_const =
        (dtype == DataType::Int(32) || dtype == DataType::Int(64) || dtype == DataType::Float(32) ||
         dtype == DataType::Float(64) || dtype == DataType::Bool());
    if (op->is_scalar() && is_simple_const) {
      node->pattern = kElemWise;
    } else {
      // For now, mark non-scalar constants as opaque; we will not fuse them.
      node->pattern = kOpaque;
    }
  }

  void VisitExpr_(const CallNode* call) final {
    ICHECK(graph_.node_map.count(call));
    IndexedForwardGraph::Node* node = graph_.node_map.at(call);
    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
    // Now we set the pattern of this call.
    //
    // If we see a call mentioning an operator, we mark it with its
    // annotated pattern.
    //
    // If the pattern is not annotated, we default to opaque.
    //
    // Finally, if the op field is not an operator node, we need to call
    // Update, as it may be an arbitrary expression.
    OpPatternKind op_pattern = kOpaque;
    if (const OpNode* opnode = call->op.as<OpNode>()) {
      auto op = GetRef<Op>(opnode);
      if (IsDynamic(call->checked_type()) && IsDataDependent(call)) {
        // The output of a shape func can't be fed to a data-dependent shape func.
        op_pattern = kOpaque;
      } else {
        op_pattern = static_cast<OpPatternKind>(fpattern[op]);
      }
    } else {
      this->Update(call->op, node, kOpaque);
    }

    node->pattern = op_pattern;
    this->Update(call->op, nullptr, kOpaque);
    const auto* rtype = call->checked_type().as<TensorTypeNode>();
    // Pass the analysis back to all the children it references.
    for (size_t i = 0; i < call->args.size(); ++i) {
      const auto* arg_type = call->args[i]->checked_type().as<TensorTypeNode>();
      // Specifically check whether the result type is the same as the argument type:
      // if so, a broadcast edge degenerates to an elementwise edge.
      OpPatternKind edge_pattern = op_pattern;
      if (edge_pattern == kBroadcast && arg_type != nullptr && rtype != nullptr &&
          attr_equal_(rtype->shape, arg_type->shape)) {
        edge_pattern = kElemWise;
      }
      this->Update(call->args[i], node, edge_pattern);
    }
    ExprVisitor::VisitExpr_(call);
    this->AddNode(call);
  }

  void VisitExpr_(const TupleNode* op) final {
    ICHECK(graph_.node_map.count(op));
    IndexedForwardGraph::Node* tuple_node = graph_.node_map.at(op);
    tuple_node->pattern = kTuple;
    for (const Expr& field : op->fields) {
      if (field->checked_type().as<TensorTypeNode>()) {
        this->Update(field, tuple_node, kInjective);
      } else {
        this->Update(field, nullptr, kOpaque);
      }
    }
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const TupleGetItemNode* op) final {
    auto tuple_type = op->tuple->checked_type().as<TupleTypeNode>();
    ICHECK(tuple_type);
    // When TVM lowers a fused function, it expects all arguments to be a Tensor or
    // a tuple containing only Tensors. But this tuple may contain a reference or
    // another tuple. To avoid modifying codegen logic, we do not allow fusing through this node
    // if the tuple contains such non-Tensor fields. However, all fields will be recursively
    // visited via the call to ExprVisitor::VisitExpr_(op) below and the corresponding visitors.
    bool has_non_tensor = false;
    for (auto ty : tuple_type->fields) {
      if (!ty.as<TensorTypeNode>()) {
        has_non_tensor = true;
        break;
      }
    }
    if (has_non_tensor) {
      this->Update(op->tuple, nullptr, kOpaque);
    } else {
      ICHECK(graph_.node_map.count(op));
      IndexedForwardGraph::Node* node = graph_.node_map.at(op);
      node->pattern = kInjective;
      this->Update(op->tuple, node, kInjective);
    }
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const VarNode* op) final { this->AddNode(op); }

  void VisitExpr_(const LetNode* op) final {
    // do not fuse through let.
    auto pre_visit = [this](const LetNode* op) {
      // Rely on the Memoizer to cache pre-visit values
      this->Update(op->var, nullptr, kOpaque);
      this->Update(op->value, nullptr, kOpaque);
      this->Update(op->body, nullptr, kOpaque);
      this->VisitExpr(op->var);
      this->VisitExpr(op->value);
    };
    auto post_visit = [this](const LetNode* op) {
      this->VisitExpr(op->body);
      this->visit_counter_[op] += 1;
      this->AddNode(op);
    };
    ExpandANormalForm(op, pre_visit, post_visit);
  }

  void VisitExpr_(const IfNode* op) final {
    // do not fuse through if.
    this->Update(op->cond, nullptr, kOpaque);
    this->Update(op->true_branch, nullptr, kOpaque);
    this->Update(op->false_branch, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const RefCreateNode* op) final {
    this->Update(op->value, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const RefReadNode* op) final {
    this->Update(op->ref, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const RefWriteNode* op) final {
    this->Update(op->ref, nullptr, kOpaque);
    this->Update(op->value, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }

  void VisitExpr_(const MatchNode* op) final {
    this->Update(op->data, nullptr, kOpaque);
    for (const Clause& c : op->clauses) {
      this->Update(c->rhs, nullptr, kOpaque);
    }
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }
};

class FuseMutator : private MixedModeMutator {
 public:
  FuseMutator(int fuse_opt_level, size_t max_fuse_depth, bool link_params)
      : fuse_opt_level_(fuse_opt_level),
        max_fuse_depth_(max_fuse_depth),
        link_params_(link_params) {}

  // Run the transform
  Expr Transform(const Expr& body) {
    return Transform(body, fuse_opt_level_, max_fuse_depth_, link_params_);
  }

 protected:
  // Run the transform
  Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth, bool link_params) {
    // setup the group map.
    auto graph = IndexedForwardGraphCreator::Create(&arena_, body);
    auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth).Partition(graph);
    for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
      ICHECK(graph.post_dfs_order[nid]->ref != nullptr);
      gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
    }
    // The following line can be used for debugging.
    // this->DebugDumpGroup(body);
    return this->Mutate(body);
  }

 private:
  int fuse_opt_level_;
  size_t max_fuse_depth_;
  bool link_params_;

  using MixedModeMutator::VisitExpr_;

  /*! \brief Temporary information from each group. */
  struct GroupInfo {
   public:
    // The parameters of the function.
    Array<Var> params;
    // The arguments used to call the function.
    Array<Expr> arguments;
    // Get the existing parameter for an expression, or allocate a new one.
    Var GetOrAllocParam(const Expr& expr, const Type& type) {
      // Run a linear scan, as most fused groups contain only a few inputs.
      for (size_t i = 0; i < arguments.size(); ++i) {
        if (expr.same_as(arguments[i])) return params[i];
      }
      // Create a new parameter.
      std::ostringstream os;
      os << "p" << params.size();
      auto var = Var(os.str(), type);
      params.push_back(var);
      arguments.push_back(expr);
      return var;
    }
  };
  /*! \brief Internal arena. */
  support::Arena arena_;
  /*! \brief The group assignment map. */
  std::unordered_map<const Object*, GraphPartitioner::Group*> gmap_;
  /*! \brief Internal group information map. */
  std::unordered_map<GraphPartitioner::Group*, GroupInfo> ginfo_;

  // Skip primitive function.
  Expr VisitExpr_(const FunctionNode* fn_node) {
    if (fn_node->HasNonzeroAttr(attr::kPrimitive)) {
      return GetRef<Expr>(fn_node);
    } else {
      return ExprMutator::VisitExpr_(fn_node);
    }
  }

  // Transform calls.
  Expr Rewrite_(const CallNode* call, const Expr& post) {
    if (call->op.as<OpNode>()) {
      static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");
      static auto fqnncanonicalize = Op::GetAttrMap<FTVMLegalize>("FTVMQnnCanonicalize");

      Op op = Downcast<Op>(call->op);
      if (fnoncomputational.get(op, false) && !fqnncanonicalize.count(op)) {
        return ExprMutator::VisitExpr_(call);
      }

      // If it is a primitive op call
      // then we must have a group assignment for it already.
      ICHECK(gmap_.count(call));
      if (call->op == stop_fusion_op) {
        return ExprMutator::VisitExpr(call->args[0]);
      }
      auto* ret_group = gmap_.at(call)->FindRoot();
      Array<Expr> new_args = GetNewArguments(call->args, ret_group);

      auto new_call = Call(call->op, new_args, call->attrs, call->type_args, call->span);

      if (ret_group->root_ref == call) {
        // This is the root of the group
        // create the new call node.
        return MakeNewFunction(ret_group, call->checked_type(), new_call);
      } else {
        // This is an intermediate node of a fused function
        // simply return the new call.
        return std::move(new_call);
      }
    } else {
      return ExprMutator::VisitExpr_(call);
    }
  }

  Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) {
    auto* ret_group = gmap_.at(tuple_node)->FindRoot();
    if (ret_group->root_ref == tuple_node) {
      return ExprMutator::VisitExpr_(tuple_node);
    }
    // This tuple is an intermediate node in the group
    Array<Expr> new_fields = GetNewArguments(tuple_node->fields, ret_group);
    return WithFields(GetRef<Tuple>(tuple_node), new_fields);
  }

  Expr Rewrite_(const TupleGetItemNode* tuple_get, const Expr& post) {
    auto* ret_group = gmap_.at(tuple_get)->FindRoot();
    auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0];
    auto new_node = TupleGetItem(new_tuple, tuple_get->index);
    if (ret_group->root_ref == tuple_get) {
      if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) {
        // Isolated. This case occurs when tuple is created by an Opaque op
        // e.g. multibox_transform_loc
        return ExprMutator::VisitExpr_(tuple_get);
      }
      // A new function whose output is a tuple field access
      return MakeNewFunction(ret_group, tuple_get->checked_type(), new_node);
    }
    // This is an intermediate node in the group
    return std::move(new_node);
  }

  Expr VisitExpr_(const LetNode* op) final {
    auto pre_visit = [this](const LetNode* op) {
      // Rely on the Memoizer to cache pre-visit values
      this->VisitExpr(op->var);
      this->VisitExpr(op->value);
    };
    auto post_visit = [this](const LetNode* op) {
      // Rely on the Memoizer to cache pre-visit values
      Var var = Downcast<Var>(this->VisitExpr(op->var));
      Expr value = this->VisitExpr(op->value);
      // Visit body and cache the op
      Expr body = this->VisitExpr(op->body);
      auto expr = GetRef<Expr>(op);
      if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
        this->memo_[expr] = expr;
      } else {
        this->memo_[expr] = Let(var, value, body);
      }
    };
    ExpandANormalForm(op, pre_visit, post_visit);
    return memo_[GetRef<Expr>(op)];
  }

  Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
    // Quickly check special properties of the fused function.
    // A pass to check if the fused op contains only reshape ops.
    class CheckReshapeOnly : public ExprVisitor {
     public:
      void VisitExpr_(const CallNode* cn) final {
        this->has_call = true;
        static auto freshape_op = Op::GetAttrMap<TReshapeOp>("TReshapeOp");

        if (!freshape_op.get(cn->op, false)) {
          this->reshape_only = false;
        }

        if (!this->reshape_only) return;
        ExprVisitor::VisitExpr_(cn);
      }

      void VisitExpr_(const VarNode* vn) final {
        if (!vn->type_annotation.defined() || !vn->type_annotation->IsInstance<TensorTypeNode>()) {
          this->reshape_only = false;
        }
      }

      bool reshape_only = true;
      bool has_call = false;
    } visitor;

    visitor(body);
    const GroupInfo& ginfo = ginfo_[group];
    auto func = Function(ginfo.params, body, ret_type, {});
    func = WithAttr(std::move(func), attr::kPrimitive, tvm::Integer(visitor.has_call));
    // TODO(mbs): "reshape" cleanup.
    if (visitor.has_call && visitor.reshape_only) {
      func = WithAttr(std::move(func), attr::kReshapeOnly, tvm::Integer(visitor.reshape_only));
    }
    return Call(func, ginfo.arguments, Attrs());
  }

  Array<Expr> GetNewArguments(const tvm::Array<Expr>& args,
                              GraphPartitioner::Group* current_group) {
    Array<Expr> new_args;
    for (auto arg : args) {
      auto* arg_group = gmap_.at(arg.get())->FindRoot();
      auto type = arg->checked_type();
      Expr new_arg = this->Mutate(arg);
      if (current_group != arg_group) {
        if (!link_params_ || new_arg.as<ConstantNode>() == nullptr) {
          Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type);
          new_args.push_back(param);
        } else {
          new_args.push_back(new_arg);
        }
      } else {
        new_args.push_back(new_arg);
      }
    }
    return new_args;
  }

  // Debug function, dump the group assignment in text.
  void DebugDumpGroup(const Expr& body) {
    std::string text = AsText(body, false, [this](const ObjectRef& expr) -> std::string {
      auto it = gmap_.find(expr.get());
      if (it == gmap_.end()) return "";
      std::ostringstream os;
      auto* group = it->second->FindRoot();
      os << " /* group=" << group << " */";
      return os.str();
    });
    LOG(INFO) << "Dump of group info:\n" << text;
  }
};

Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, bool link_params,
             const IRModule& module) {
  return FuseMutator(fuse_opt_level, max_fuse_depth, link_params).Transform(expr);
}

namespace transform {

Pass FuseOps(int fuse_opt_level) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        bool link_params = false;
        Executor executor =
            m->GetAttr<Executor>(tvm::attr::kExecutor).value_or(NullValue<Executor>());
        link_params = executor.defined()
                          ? executor->attrs.GetAttr<Bool>("link-params").value_or(Bool(link_params))
                          : link_params;
        link_params = pc->GetConfig("relay.FuseOps.link_params", Bool(link_params)).value();
        int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
        auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth", Integer(kMaxFusedOps));
        return Downcast<Function>(
            FuseOps(f, opt_level, max_fuse_depth.value().IntValue(), link_params, m));
      };
  return CreateFunctionPass(pass_func, 0, "FuseOps", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.FuseOps").set_body_typed(FuseOps);

}  // namespace transform

}  // namespace relay
}  // namespace tvm