1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file combine_parallel_op.cc |
 * \brief Abstract class to combine parallel ops and their successive element-wise or broadcast ops.
24 | */ |
25 | |
26 | #include "combine_parallel_op.h" |
27 | |
28 | #include <tvm/node/structural_hash.h> |
29 | #include <tvm/relay/analysis.h> |
30 | #include <tvm/relay/attrs/nn.h> |
31 | #include <tvm/relay/attrs/transform.h> |
32 | #include <tvm/relay/expr_functor.h> |
33 | #include <tvm/relay/op.h> |
34 | #include <tvm/relay/op_attr_types.h> |
35 | #include <tvm/relay/transform.h> |
36 | |
37 | #include <algorithm> |
38 | #include <unordered_map> |
39 | #include <unordered_set> |
40 | #include <utility> |
41 | |
42 | #include "expr_subst.h" |
43 | #include "pattern_utils.h" |
44 | |
45 | namespace tvm { |
46 | namespace relay { |
47 | |
48 | BranchGroupFinder::BranchGroupFinder(const Op& op, FIsSupportedOp fis_supported_op, |
49 | FAreCompatibleOps fare_compatible_ops) |
50 | : cached_op_(op), |
51 | fis_supported_op_(fis_supported_op), |
52 | fare_compatible_ops_(fare_compatible_ops) {} |
53 | |
54 | std::vector<Group> BranchGroupFinder::Find(const Expr& expr) { |
55 | this->VisitExpr(expr); |
56 | |
57 | std::vector<Group> groups; |
58 | for (const auto& root : op_roots_) { |
59 | const auto& children = children_map_.at(root); |
60 | size_t ngroups = groups.size(); |
61 | for (const CallNode* child : children) { |
62 | if (child->op != cached_op_) continue; |
63 | |
64 | auto&& branch = CreateBranch(child); |
65 | // add the branch to a group, or create a new group |
66 | auto it = std::find_if(groups.begin() + ngroups, groups.end(), [&](const Group& group) { |
67 | ICHECK(!group.empty() && !group[0].empty()); |
68 | return fare_compatible_ops_(child, group[0][0]); |
69 | }); |
70 | if (it != groups.end()) { |
71 | it->push_back(branch); |
72 | } else { |
73 | groups.emplace_back(); |
74 | // each group has at least one branch |
75 | groups.back().push_back(branch); |
76 | } |
77 | } |
78 | } |
79 | return groups; |
80 | } |
81 | |
82 | // Create a branch starting from op. |
83 | Branch BranchGroupFinder::CreateBranch(const CallNode* op) { |
84 | auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern" ); |
85 | // each branch has at least one element, the first element is always op |
86 | Branch branch{op}; |
87 | auto it = children_map_.find(GetRef<Expr>(branch.back())); |
88 | while (it != children_map_.end() && it->second.size() == 1) { |
89 | const CallNode* call = it->second[0]; |
90 | auto pattern = fpattern[Downcast<Op>(call->op)]; |
91 | if (pattern <= kBroadcast) { |
92 | branch.push_back(call); |
93 | it = children_map_.find(GetRef<Expr>(branch.back())); |
94 | } else { |
95 | break; |
96 | } |
97 | } |
98 | return branch; |
99 | } |
100 | |
101 | void BranchGroupFinder::VisitExpr_(const CallNode* n) { |
102 | ExprVisitor::VisitExpr_(n); |
103 | if (n->op == cached_op_ && fis_supported_op_(n)) { |
104 | op_roots_.insert(n->args[0]); |
105 | children_map_[n->args[0]].push_back(n); |
106 | } else { |
107 | for (size_t i = 0; i < n->args.size(); i++) { |
108 | children_map_[n->args[i]].push_back(n); |
109 | } |
110 | } |
111 | } |
112 | |
113 | ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches) |
114 | : cached_op_(Op::Get(op_name)), min_num_branches_(min_num_branches) {} |
115 | |
116 | Expr ParallelOpCombiner::Combine(const Expr& expr) { |
117 | auto groups = BranchGroupFinder( |
118 | cached_op_, [&](const CallNode* n) { return IsSupportedOp(n); }, |
119 | [&](const CallNode* a, const CallNode* b) { return CanOpsBeCombined(a, b); }) |
120 | .Find(expr); |
121 | for (const Group& group : groups) { |
122 | if (group.size() < min_num_branches_) { |
123 | continue; |
124 | } |
125 | CombineBranches(group); |
126 | } |
127 | return ExprSubst(expr, std::move(subst_map_)); |
128 | } |
129 | |
130 | void ParallelOpCombiner::CombineBranches(const Group& branches) { |
131 | Call combined = MakeCombinedOp(branches); |
132 | auto it = std::min_element(branches.begin(), branches.end(), |
133 | [](const Branch& branch_a, const Branch& branch_b) { |
134 | return branch_a.size() < branch_b.size(); |
135 | }); |
136 | size_t depth = it->size(); |
137 | size_t i; |
138 | // starting from 1 to skip the op |
139 | for (i = 1; i < depth; i++) { |
140 | size_t parent_index; |
141 | for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) { |
142 | if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break; |
143 | } |
144 | ICHECK_NE(parent_index, branches[0][i]->args.size()); |
145 | if (!CheckLevel(branches, i, parent_index)) break; |
146 | combined = MakeCombinedCallFromFollowingOps(combined, branches, i, parent_index); |
147 | } |
148 | UpdateGroupOutput(combined, branches, i - 1, &subst_map_); |
149 | } |
150 | |
151 | bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth, size_t parent_index) { |
152 | const CallNode* call = branches[0][depth]; |
153 | tvm::StructuralEqual attrs_equal; |
154 | // check if all branches in current depth can be combined |
155 | for (auto it = branches.begin() + 1; it != branches.end(); it++) { |
156 | const Branch& branch = *it; |
157 | if (!branch[depth]->op.same_as(call->op) || !attrs_equal(branch[depth]->attrs, call->attrs) || |
158 | branch[depth]->args.size() != call->args.size()) { |
159 | return false; |
160 | } |
161 | |
162 | if (branch[depth]->args[parent_index].get() != branch[depth - 1]) return false; |
163 | |
164 | // Check args |
165 | for (size_t i = 0; i < call->args.size(); i++) { |
166 | if (i == parent_index) continue; |
167 | |
168 | if (!IsArgCompatible(call, branch[depth], i) || |
169 | !attrs_equal(call->attrs, branch[depth]->attrs)) { |
170 | return false; |
171 | } |
172 | } |
173 | } |
174 | return true; |
175 | } |
176 | |
177 | } // namespace relay |
178 | } // namespace tvm |
179 | |