1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file combine_parallel_op.cc
23 * \brief Abstract class to combine parallel ops and their successive element-wise ops.
24 */
25
26#include "combine_parallel_op.h"
27
28#include <tvm/node/structural_hash.h>
29#include <tvm/relay/analysis.h>
30#include <tvm/relay/attrs/nn.h>
31#include <tvm/relay/attrs/transform.h>
32#include <tvm/relay/expr_functor.h>
33#include <tvm/relay/op.h>
34#include <tvm/relay/op_attr_types.h>
35#include <tvm/relay/transform.h>
36
37#include <algorithm>
38#include <unordered_map>
39#include <unordered_set>
40#include <utility>
41
42#include "expr_subst.h"
43#include "pattern_utils.h"
44
45namespace tvm {
46namespace relay {
47
48BranchGroupFinder::BranchGroupFinder(const Op& op, FIsSupportedOp fis_supported_op,
49 FAreCompatibleOps fare_compatible_ops)
50 : cached_op_(op),
51 fis_supported_op_(fis_supported_op),
52 fare_compatible_ops_(fare_compatible_ops) {}
53
54std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
55 this->VisitExpr(expr);
56
57 std::vector<Group> groups;
58 for (const auto& root : op_roots_) {
59 const auto& children = children_map_.at(root);
60 size_t ngroups = groups.size();
61 for (const CallNode* child : children) {
62 if (child->op != cached_op_) continue;
63
64 auto&& branch = CreateBranch(child);
65 // add the branch to a group, or create a new group
66 auto it = std::find_if(groups.begin() + ngroups, groups.end(), [&](const Group& group) {
67 ICHECK(!group.empty() && !group[0].empty());
68 return fare_compatible_ops_(child, group[0][0]);
69 });
70 if (it != groups.end()) {
71 it->push_back(branch);
72 } else {
73 groups.emplace_back();
74 // each group has at least one branch
75 groups.back().push_back(branch);
76 }
77 }
78 }
79 return groups;
80}
81
82// Create a branch starting from op.
83Branch BranchGroupFinder::CreateBranch(const CallNode* op) {
84 auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
85 // each branch has at least one element, the first element is always op
86 Branch branch{op};
87 auto it = children_map_.find(GetRef<Expr>(branch.back()));
88 while (it != children_map_.end() && it->second.size() == 1) {
89 const CallNode* call = it->second[0];
90 auto pattern = fpattern[Downcast<Op>(call->op)];
91 if (pattern <= kBroadcast) {
92 branch.push_back(call);
93 it = children_map_.find(GetRef<Expr>(branch.back()));
94 } else {
95 break;
96 }
97 }
98 return branch;
99}
100
101void BranchGroupFinder::VisitExpr_(const CallNode* n) {
102 ExprVisitor::VisitExpr_(n);
103 if (n->op == cached_op_ && fis_supported_op_(n)) {
104 op_roots_.insert(n->args[0]);
105 children_map_[n->args[0]].push_back(n);
106 } else {
107 for (size_t i = 0; i < n->args.size(); i++) {
108 children_map_[n->args[i]].push_back(n);
109 }
110 }
111}
112
// Resolves the target op once by name so all later comparisons use the cached
// handle; min_num_branches is the smallest group size worth combining
// (smaller groups are skipped in Combine).
ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches)
    : cached_op_(Op::Get(op_name)), min_num_branches_(min_num_branches) {}
115
116Expr ParallelOpCombiner::Combine(const Expr& expr) {
117 auto groups = BranchGroupFinder(
118 cached_op_, [&](const CallNode* n) { return IsSupportedOp(n); },
119 [&](const CallNode* a, const CallNode* b) { return CanOpsBeCombined(a, b); })
120 .Find(expr);
121 for (const Group& group : groups) {
122 if (group.size() < min_num_branches_) {
123 continue;
124 }
125 CombineBranches(group);
126 }
127 return ExprSubst(expr, std::move(subst_map_));
128}
129
// Fuses one group of parallel branches: combines the root ops, then walks the
// branches level by level, merging each following op while all branches agree.
void ParallelOpCombiner::CombineBranches(const Group& branches) {
  Call combined = MakeCombinedOp(branches);
  // The combinable depth is bounded by the shortest branch in the group.
  auto it = std::min_element(branches.begin(), branches.end(),
                             [](const Branch& branch_a, const Branch& branch_b) {
                               return branch_a.size() < branch_b.size();
                             });
  size_t depth = it->size();
  size_t i;
  // starting from 1 to skip the op
  for (i = 1; i < depth; i++) {
    size_t parent_index;
    // Find which argument of the op at this level consumes the previous
    // level's output (checked against branch 0; CheckLevel verifies the rest).
    for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) {
      if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break;
    }
    ICHECK_NE(parent_index, branches[0][i]->args.size());
    // Stop at the first level that cannot be combined across all branches.
    if (!CheckLevel(branches, i, parent_index)) break;
    combined = MakeCombinedCallFromFollowingOps(combined, branches, i, parent_index);
  }
  // i - 1 is the last level that was successfully combined; record the
  // replacement of each branch's output at that level in subst_map_, which
  // Combine later applies via ExprSubst.
  UpdateGroupOutput(combined, branches, i - 1, &subst_map_);
}
150
151bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth, size_t parent_index) {
152 const CallNode* call = branches[0][depth];
153 tvm::StructuralEqual attrs_equal;
154 // check if all branches in current depth can be combined
155 for (auto it = branches.begin() + 1; it != branches.end(); it++) {
156 const Branch& branch = *it;
157 if (!branch[depth]->op.same_as(call->op) || !attrs_equal(branch[depth]->attrs, call->attrs) ||
158 branch[depth]->args.size() != call->args.size()) {
159 return false;
160 }
161
162 if (branch[depth]->args[parent_index].get() != branch[depth - 1]) return false;
163
164 // Check args
165 for (size_t i = 0; i < call->args.size(); i++) {
166 if (i == parent_index) continue;
167
168 if (!IsArgCompatible(call, branch[depth], i) ||
169 !attrs_equal(call->attrs, branch[depth]->attrs)) {
170 return false;
171 }
172 }
173 }
174 return true;
175}
176
177} // namespace relay
178} // namespace tvm
179