1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/collage/candidate_partition.cc |
22 | * \brief A potential partition in the Collage search. |
23 | */ |
24 | |
25 | #include "./candidate_partition.h" |
26 | |
27 | #include <tvm/relay/analysis.h> |
28 | #include <tvm/relay/attrs/memory.h> |
29 | #include <tvm/relay/transform.h> |
30 | |
31 | #include "../transforms/compiler_function_utils.h" |
32 | #include "./candidate_function_cache.h" |
33 | #include "./candidate_set.h" |
34 | #include "./partition_rule.h" |
35 | #include "./partition_spec.h" |
36 | #include "./utils.h" |
37 | |
38 | namespace tvm { |
39 | namespace relay { |
40 | namespace collage { |
41 | |
TVM_REGISTER_NODE_TYPE(CandidatePartitionNode);

// Expose the reflectable fields to TVM's attribute-visitor machinery (used for
// structural equality/hashing and serialization). The visit order is
// significant to the object protocol, so do not reorder these calls.
void CandidatePartitionNode::VisitAttrs(AttrVisitor* v) {
  v->Visit("rule_name", &rule_name_);
  v->Visit("sub_graph", &sub_graph_);
  // NOTE(review): spec_ is stored as ObjectRef but is actually a PartitionSpec
  // (see partition_spec() below).
  v->Visit("spec", &spec_);
  // TODO(mbs): cost_
}
50 | |
51 | PartitionSpec CandidatePartitionNode::partition_spec() const { |
52 | return Downcast<PartitionSpec>(spec_); |
53 | } |
54 | |
55 | std::string CandidatePartitionNode::partition_spec_name() const { |
56 | return Downcast<PartitionSpec>(spec_)->spec_name_; |
57 | } |
58 | |
59 | Target CandidatePartitionNode::target() const { return Downcast<PartitionSpec>(spec_)->target_; } |
60 | |
61 | std::string CandidatePartitionNode::ToSummary(const DataflowGraph& dataflow_graph) const { |
62 | std::ostringstream os; |
63 | os << sub_graph_->label_; |
64 | os << " | (" ; |
65 | bool first = true; |
66 | for (PostDfsIndex index : sub_graph_->input_) { |
67 | Expr sub_expr = dataflow_graph.index_to_node(index)->ref(); |
68 | if (CanInline(sub_expr)) { |
69 | continue; |
70 | } |
71 | if (first) { |
72 | first = false; |
73 | } else { |
74 | os << ", " ; |
75 | } |
76 | os << PrettyPrint(sub_expr->checked_type()); |
77 | } |
78 | os << ") -> (" ; |
79 | first = true; |
80 | for (PostDfsIndex index : sub_graph_->exit_) { |
81 | Expr sub_expr = dataflow_graph.index_to_node(index)->ref(); |
82 | if (CanInline(sub_expr)) { |
83 | continue; |
84 | } |
85 | if (first) { |
86 | first = false; |
87 | } else { |
88 | os << ", " ; |
89 | } |
90 | os << PrettyPrint(sub_expr->checked_type()); |
91 | } |
92 | os << ") | " ; |
93 | os << sub_graph_->inside_.ToString(); |
94 | os << " | " ; |
95 | os << partition_spec_name(); |
96 | os << " | " ; |
97 | os << cost_.ToString(); |
98 | return os.str(); |
99 | } |
100 | |
101 | std::string CandidatePartitionNode::ToString() const { |
102 | std::ostringstream os; |
103 | os << "{rule_name=" << rule_name_; |
104 | os << ",sub_graph=" << sub_graph_->ToString(); |
105 | os << ",spec_name=" << partition_spec_name(); |
106 | if (!cost_.is_unknown()) { |
107 | os << ",cost=" << cost_.ToString(); |
108 | } |
109 | os << "}" ; |
110 | return os.str(); |
111 | } |
112 | |
113 | namespace { |
114 | /*! |
115 | * \brief If function's body is a call to an inlined "Primitive" function, return it. |
116 | * Otherwise return function directly. |
117 | */ |
118 | Function GetPrimitiveFunction(const Function& function) { |
119 | if (const auto* call_node = function->body.as<CallNode>()) { |
120 | if (const auto* function_node = call_node->op.as<FunctionNode>()) { |
121 | if (function_node->HasNonzeroAttr(attr::kPrimitive)) { |
122 | return GetRef<Function>(function_node); |
123 | } |
124 | } |
125 | } |
126 | return function; |
127 | } |
128 | |
129 | /*! |
130 | * \brief Eta-expand any tuple arguments of \p function. Ie rewrite: |
131 | * \code |
132 | * f(x: (t1, t2)) { ... x ... } |
133 | * \endcode |
134 | * to |
135 | * \code |
136 | * f(x_1: t1, x_2: t2) { ... (x_1, x_2) ... } |
137 | * \endcode |
138 | */ |
139 | Function EtaExpandTuples(const Function& function) { |
140 | Map<Var, Expr> subst; |
141 | Array<Var> new_params; |
142 | for (const auto& param : function->params) { |
143 | std::vector<TensorType> tensor_types = FlattenTupleType(param->type_annotation); |
144 | if (tensor_types.size() == 1) { |
145 | new_params.push_back(param); |
146 | } else { |
147 | Array<Expr> fields; |
148 | for (size_t i = 0; i < tensor_types.size(); ++i) { |
149 | Var new_param(param->name_hint() + "_" + std::to_string(i), tensor_types[i], param->span); |
150 | new_param->checked_type_ = tensor_types[i]; |
151 | new_params.push_back(new_param); |
152 | fields.push_back(new_param); |
153 | } |
154 | Tuple new_tuple(fields); |
155 | subst.Set(param, new_tuple); |
156 | } |
157 | } |
158 | if (subst.empty()) { |
159 | return function; |
160 | } |
161 | return WithFields(function, new_params, Bind(function->body, subst)); |
162 | } |
163 | |
164 | } // namespace |
165 | |
166 | Cost CandidatePartitionNode::EstimatedCost( |
167 | const DataflowGraph& dataflow_graph, const CostEstimator& cost_estimator, |
168 | const std::shared_ptr<CandidateFunctionCache>& cache) const { |
169 | if (cost_.is_unknown()) { |
170 | VLOG_CONTEXT << "spec " << partition_spec_name(); |
171 | Function = sub_graph_->ExtractAsFunction(dataflow_graph); |
172 | VLOG(2) << "Extracted function:" << std::endl << PrettyPrint(extracted_function); |
173 | extracted_function = EtaExpandTuples(extracted_function); |
174 | VLOG(2) << "Validating function:" << std::endl << PrettyPrint(extracted_function); |
175 | String error = partition_spec()->validate_sub_graph_func_(extracted_function); |
176 | if (!error.empty()) { |
177 | cost_ = Cost::Invalid(); |
178 | VLOG(1) << "Unable to rewrite function: " << error; |
179 | } else { |
180 | // The extracted function may be the eta-expansion of a "Primitive" function. |
181 | // If so we want the cached external name and cost to be w.r.t. that function |
182 | // rather than the outer so that we'll get a cache hit when we outline functions |
183 | // in the final program. |
184 | Function primitive_function = GetPrimitiveFunction(extracted_function); |
185 | CandidateFunctionCache::Entry& entry = |
186 | cache->GetEntry(sub_graph_->label_, primitive_function); |
187 | if (entry.cost.is_unknown()) { |
188 | IRModule mod = IRModule::FromExpr(extracted_function); |
189 | VLOG(1) << "Outlining:" << std::endl << PrettyPrint(mod); |
190 | mod = OutlineCompilerFunctions(cache)(mod); |
191 | VLOG(1) << "Estimating cost of:" << std::endl |
192 | << PrettyPrint(mod) << std::endl |
193 | << "using target " << target()->ToDebugString(); |
194 | entry.cost = cost_estimator->Estimate(mod, target()); |
195 | VLOG(1) << "Measured cost as " << entry.cost.ToString(); |
196 | } else { |
197 | VLOG(1) << "Reusing cost " << entry.cost.ToString() |
198 | << " cached in candidate function cache" ; |
199 | } |
200 | cost_ = entry.cost; |
201 | } |
202 | } else { |
203 | VLOG(1) << "Reusing cost " << cost_.ToString() << " cached in candidate" ; |
204 | } |
205 | return cost_; |
206 | } |
207 | |
208 | CandidatePartition::CandidatePartition(String rule_name, SubGraph sub_graph, |
209 | ObjectRef /* actually PartitionSpec */ spec, Cost cost) { |
210 | auto node = runtime::make_object<CandidatePartitionNode>(); |
211 | node->rule_name_ = std::move(rule_name); |
212 | node->sub_graph_ = std::move(sub_graph); |
213 | node->spec_ = std::move(spec); |
214 | node->cost_ = cost; |
215 | data_ = std::move(node); |
216 | } |
217 | |
218 | CandidatePartition WithRuleName(CandidatePartition candidate, String rule_name) { |
219 | if (rule_name == candidate->rule_name_) { |
220 | return candidate; |
221 | } |
222 | auto* node = candidate.CopyOnWrite(); |
223 | node->rule_name_ = std::move(rule_name); |
224 | return GetRef<CandidatePartition>(node); |
225 | } |
226 | |
227 | CandidatePartition WithSubGraph(CandidatePartition candidate, SubGraph sub_graph) { |
228 | if (sub_graph == candidate->sub_graph_) { |
229 | return candidate; |
230 | } |
231 | auto* node = candidate.CopyOnWrite(); |
232 | node->sub_graph_ = std::move(sub_graph); |
233 | return GetRef<CandidatePartition>(node); |
234 | } |
235 | |
236 | bool CandidatePartition::operator<(const CandidatePartition& that) const { |
237 | // Order lexicographically on sub-graphs. |
238 | if (*get()->sub_graph_.get() < *that->sub_graph_.get()) { |
239 | return true; |
240 | } |
241 | if (*that->sub_graph_.get() < *get()->sub_graph_.get()) { |
242 | return false; |
243 | } |
244 | // Break ties by rule name. |
245 | return get()->rule_name_ < that->rule_name_; |
246 | } |
247 | |
248 | bool CandidatePartition::AreTouching(const DataflowGraph& dataflow_graph, |
249 | const CandidatePartition& that) const { |
250 | return get()->spec_ == that->spec_ && |
251 | get()->sub_graph_.AreTouching(dataflow_graph, that->sub_graph_); |
252 | } |
253 | |
254 | CandidatePartition CandidatePartition::DisjointUnion(const DataflowGraph& dataflow_graph, |
255 | const CandidatePartition& that) const { |
256 | ICHECK_EQ(get()->spec_, that->spec_); |
257 | return CandidatePartition(UnionLabels(get()->rule_name_, that->rule_name_), |
258 | get()->sub_graph_.DisjointUnion(dataflow_graph, that->sub_graph_), |
259 | get()->spec_, get()->cost_ + that->cost_); |
260 | } |
261 | |
262 | /*static*/ |
263 | CandidatePartition CandidatePartition::DisjointUnion(const DataflowGraph& dataflow_graph, |
264 | std::vector<CandidatePartition> candidates) { |
265 | ICHECK_GT(candidates.size(), 1); |
266 | CandidatePartition result = candidates.front(); |
267 | for (size_t i = 1; i < candidates.size(); ++i) { |
268 | result = result.DisjointUnion(dataflow_graph, candidates[i]); |
269 | } |
270 | return result; |
271 | } |
272 | |
273 | /*static*/ |
274 | Expr CandidatePartition::ParallelRewrite(const DataflowGraph& dataflow_graph, |
275 | const std::vector<CandidatePartition>& candidates) { |
276 | std::vector<SubGraph> sub_graphs; |
277 | sub_graphs.reserve(candidates.size()); |
278 | for (const auto& candidate : candidates) { |
279 | sub_graphs.emplace_back(candidate->sub_graph_); |
280 | } |
281 | return SubGraph::ParallelRewrite(dataflow_graph, sub_graphs); |
282 | } |
283 | |
/*static*/
/*!
 * \brief Repeatedly merges any pair of touching candidates (from the same spec) until no
 * more merges are possible, returning the coalesced candidates in canonical order.
 */
std::vector<CandidatePartition> CandidatePartition::MaxCoalesce(
    const DataflowGraph& dataflow_graph, std::vector<CandidatePartition> candidates) {
  VLOG(1) << "Running MaxCoalesce over " << candidates.size() << " candidates";
  // This is an eager version of using the simple (kOpaque, kOpaque) combiner.

  // Switch to set representation.
  CandidateSet result_set(std::move(candidates));

  // Until fixed point...
  size_t num_rounds = 0;
  while (result_set.PrepareForNextRound()) {
    VLOG_CONTEXT << "round " << ++num_rounds;
    VLOG(1) << "checking " << result_set.size() << " candidates (" << result_set.first_new_index()
            << " existing)";
    // Candidates merged this round must be ignored for the rest of the round; their
    // union re-enters consideration next round via result_set.Add below.
    IndexSet removed_this_round(result_set.size());  // over candidate indexes!

    // Build map from post-dfs indices to the indices of candidates with corresponding entry node.
    // NOTE: the index set is over candidate indices not post-dfs indices!
    std::vector<IndexSet> entry_map(dataflow_graph.size(), IndexSet(result_set.size()));
    for (size_t i = 0; i < result_set.size(); ++i) {
      CandidatePartition candidate = result_set.at(i);
      for (PostDfsIndex entry_index : candidate->sub_graph_->entry_) {
        entry_map[entry_index].Add(i);
      }
    }

    for (size_t i = 0; i < result_set.size(); ++i) {
      if (removed_this_round[i]) {
        // Already merged.
        continue;
      }
      CandidatePartition upstream = result_set.at(i);
      // Narrow our search to just those candidates which could touch.
      // A downstream candidate can only touch if one of its entry nodes is an
      // output node of the upstream candidate.
      IndexSet possible_downstream(result_set.size());  // over candidate indexes!
      for (PostDfsIndex output_index : upstream->sub_graph_->output_) {
        possible_downstream = possible_downstream | entry_map[output_index];
      }
      for (size_t j : possible_downstream) {
        if (removed_this_round[j]) {
          // Already merged.
          continue;
        }
        if (i == j) {
          // Ignore self.
          continue;
        }
        CandidatePartition downstream = result_set.at(j);
        // AreTouching also checks the two candidates share a spec.
        if (!upstream.AreTouching(dataflow_graph, downstream)) {
          continue;
        }
        CandidatePartition new_candidate = upstream.DisjointUnion(dataflow_graph, downstream);
        VLOG(2) << "Merging upstream candidate " << upstream->ToString()
                << " and downstream candidate " << downstream->ToString() << " to yield "
                << new_candidate->ToString();
        // Replace both halves with their union; the union is revisited next round.
        result_set.Add(dataflow_graph, new_candidate);
        result_set.Remove(upstream);
        removed_this_round.Add(i);
        result_set.Remove(downstream);
        removed_this_round.Add(j);
      }
    }
  }

  // Restore canonical order.
  result_set.sort();

  VLOG(1) << "MaxCoalesce produced " << result_set.size() << " candidates";
  return result_set.MovedCurrentCandidates();
}
354 | |
355 | } // namespace collage |
356 | } // namespace relay |
357 | } // namespace tvm |
358 | |