1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/collage/candidate_partition.cc
22 * \brief A potential partition in the Collage search.
23 */
24
25#include "./candidate_partition.h"
26
27#include <tvm/relay/analysis.h>
28#include <tvm/relay/attrs/memory.h>
29#include <tvm/relay/transform.h>
30
31#include "../transforms/compiler_function_utils.h"
32#include "./candidate_function_cache.h"
33#include "./candidate_set.h"
34#include "./partition_rule.h"
35#include "./partition_spec.h"
36#include "./utils.h"
37
38namespace tvm {
39namespace relay {
40namespace collage {
41
42TVM_REGISTER_NODE_TYPE(CandidatePartitionNode);
43
44void CandidatePartitionNode::VisitAttrs(AttrVisitor* v) {
45 v->Visit("rule_name", &rule_name_);
46 v->Visit("sub_graph", &sub_graph_);
47 v->Visit("spec", &spec_);
48 // TODO(mbs): cost_
49}
50
51PartitionSpec CandidatePartitionNode::partition_spec() const {
52 return Downcast<PartitionSpec>(spec_);
53}
54
55std::string CandidatePartitionNode::partition_spec_name() const {
56 return Downcast<PartitionSpec>(spec_)->spec_name_;
57}
58
59Target CandidatePartitionNode::target() const { return Downcast<PartitionSpec>(spec_)->target_; }
60
61std::string CandidatePartitionNode::ToSummary(const DataflowGraph& dataflow_graph) const {
62 std::ostringstream os;
63 os << sub_graph_->label_;
64 os << " | (";
65 bool first = true;
66 for (PostDfsIndex index : sub_graph_->input_) {
67 Expr sub_expr = dataflow_graph.index_to_node(index)->ref();
68 if (CanInline(sub_expr)) {
69 continue;
70 }
71 if (first) {
72 first = false;
73 } else {
74 os << ", ";
75 }
76 os << PrettyPrint(sub_expr->checked_type());
77 }
78 os << ") -> (";
79 first = true;
80 for (PostDfsIndex index : sub_graph_->exit_) {
81 Expr sub_expr = dataflow_graph.index_to_node(index)->ref();
82 if (CanInline(sub_expr)) {
83 continue;
84 }
85 if (first) {
86 first = false;
87 } else {
88 os << ", ";
89 }
90 os << PrettyPrint(sub_expr->checked_type());
91 }
92 os << ") | ";
93 os << sub_graph_->inside_.ToString();
94 os << " | ";
95 os << partition_spec_name();
96 os << " | ";
97 os << cost_.ToString();
98 return os.str();
99}
100
101std::string CandidatePartitionNode::ToString() const {
102 std::ostringstream os;
103 os << "{rule_name=" << rule_name_;
104 os << ",sub_graph=" << sub_graph_->ToString();
105 os << ",spec_name=" << partition_spec_name();
106 if (!cost_.is_unknown()) {
107 os << ",cost=" << cost_.ToString();
108 }
109 os << "}";
110 return os.str();
111}
112
113namespace {
114/*!
115 * \brief If function's body is a call to an inlined "Primitive" function, return it.
116 * Otherwise return function directly.
117 */
118Function GetPrimitiveFunction(const Function& function) {
119 if (const auto* call_node = function->body.as<CallNode>()) {
120 if (const auto* function_node = call_node->op.as<FunctionNode>()) {
121 if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
122 return GetRef<Function>(function_node);
123 }
124 }
125 }
126 return function;
127}
128
129/*!
130 * \brief Eta-expand any tuple arguments of \p function. Ie rewrite:
131 * \code
132 * f(x: (t1, t2)) { ... x ... }
133 * \endcode
134 * to
135 * \code
136 * f(x_1: t1, x_2: t2) { ... (x_1, x_2) ... }
137 * \endcode
138 */
139Function EtaExpandTuples(const Function& function) {
140 Map<Var, Expr> subst;
141 Array<Var> new_params;
142 for (const auto& param : function->params) {
143 std::vector<TensorType> tensor_types = FlattenTupleType(param->type_annotation);
144 if (tensor_types.size() == 1) {
145 new_params.push_back(param);
146 } else {
147 Array<Expr> fields;
148 for (size_t i = 0; i < tensor_types.size(); ++i) {
149 Var new_param(param->name_hint() + "_" + std::to_string(i), tensor_types[i], param->span);
150 new_param->checked_type_ = tensor_types[i];
151 new_params.push_back(new_param);
152 fields.push_back(new_param);
153 }
154 Tuple new_tuple(fields);
155 subst.Set(param, new_tuple);
156 }
157 }
158 if (subst.empty()) {
159 return function;
160 }
161 return WithFields(function, new_params, Bind(function->body, subst));
162}
163
164} // namespace
165
166Cost CandidatePartitionNode::EstimatedCost(
167 const DataflowGraph& dataflow_graph, const CostEstimator& cost_estimator,
168 const std::shared_ptr<CandidateFunctionCache>& cache) const {
169 if (cost_.is_unknown()) {
170 VLOG_CONTEXT << "spec " << partition_spec_name();
171 Function extracted_function = sub_graph_->ExtractAsFunction(dataflow_graph);
172 VLOG(2) << "Extracted function:" << std::endl << PrettyPrint(extracted_function);
173 extracted_function = EtaExpandTuples(extracted_function);
174 VLOG(2) << "Validating function:" << std::endl << PrettyPrint(extracted_function);
175 String error = partition_spec()->validate_sub_graph_func_(extracted_function);
176 if (!error.empty()) {
177 cost_ = Cost::Invalid();
178 VLOG(1) << "Unable to rewrite function: " << error;
179 } else {
180 // The extracted function may be the eta-expansion of a "Primitive" function.
181 // If so we want the cached external name and cost to be w.r.t. that function
182 // rather than the outer so that we'll get a cache hit when we outline functions
183 // in the final program.
184 Function primitive_function = GetPrimitiveFunction(extracted_function);
185 CandidateFunctionCache::Entry& entry =
186 cache->GetEntry(sub_graph_->label_, primitive_function);
187 if (entry.cost.is_unknown()) {
188 IRModule mod = IRModule::FromExpr(extracted_function);
189 VLOG(1) << "Outlining:" << std::endl << PrettyPrint(mod);
190 mod = OutlineCompilerFunctions(cache)(mod);
191 VLOG(1) << "Estimating cost of:" << std::endl
192 << PrettyPrint(mod) << std::endl
193 << "using target " << target()->ToDebugString();
194 entry.cost = cost_estimator->Estimate(mod, target());
195 VLOG(1) << "Measured cost as " << entry.cost.ToString();
196 } else {
197 VLOG(1) << "Reusing cost " << entry.cost.ToString()
198 << " cached in candidate function cache";
199 }
200 cost_ = entry.cost;
201 }
202 } else {
203 VLOG(1) << "Reusing cost " << cost_.ToString() << " cached in candidate";
204 }
205 return cost_;
206}
207
208CandidatePartition::CandidatePartition(String rule_name, SubGraph sub_graph,
209 ObjectRef /* actually PartitionSpec */ spec, Cost cost) {
210 auto node = runtime::make_object<CandidatePartitionNode>();
211 node->rule_name_ = std::move(rule_name);
212 node->sub_graph_ = std::move(sub_graph);
213 node->spec_ = std::move(spec);
214 node->cost_ = cost;
215 data_ = std::move(node);
216}
217
218CandidatePartition WithRuleName(CandidatePartition candidate, String rule_name) {
219 if (rule_name == candidate->rule_name_) {
220 return candidate;
221 }
222 auto* node = candidate.CopyOnWrite();
223 node->rule_name_ = std::move(rule_name);
224 return GetRef<CandidatePartition>(node);
225}
226
227CandidatePartition WithSubGraph(CandidatePartition candidate, SubGraph sub_graph) {
228 if (sub_graph == candidate->sub_graph_) {
229 return candidate;
230 }
231 auto* node = candidate.CopyOnWrite();
232 node->sub_graph_ = std::move(sub_graph);
233 return GetRef<CandidatePartition>(node);
234}
235
236bool CandidatePartition::operator<(const CandidatePartition& that) const {
237 // Order lexicographically on sub-graphs.
238 if (*get()->sub_graph_.get() < *that->sub_graph_.get()) {
239 return true;
240 }
241 if (*that->sub_graph_.get() < *get()->sub_graph_.get()) {
242 return false;
243 }
244 // Break ties by rule name.
245 return get()->rule_name_ < that->rule_name_;
246}
247
248bool CandidatePartition::AreTouching(const DataflowGraph& dataflow_graph,
249 const CandidatePartition& that) const {
250 return get()->spec_ == that->spec_ &&
251 get()->sub_graph_.AreTouching(dataflow_graph, that->sub_graph_);
252}
253
254CandidatePartition CandidatePartition::DisjointUnion(const DataflowGraph& dataflow_graph,
255 const CandidatePartition& that) const {
256 ICHECK_EQ(get()->spec_, that->spec_);
257 return CandidatePartition(UnionLabels(get()->rule_name_, that->rule_name_),
258 get()->sub_graph_.DisjointUnion(dataflow_graph, that->sub_graph_),
259 get()->spec_, get()->cost_ + that->cost_);
260}
261
262/*static*/
263CandidatePartition CandidatePartition::DisjointUnion(const DataflowGraph& dataflow_graph,
264 std::vector<CandidatePartition> candidates) {
265 ICHECK_GT(candidates.size(), 1);
266 CandidatePartition result = candidates.front();
267 for (size_t i = 1; i < candidates.size(); ++i) {
268 result = result.DisjointUnion(dataflow_graph, candidates[i]);
269 }
270 return result;
271}
272
273/*static*/
274Expr CandidatePartition::ParallelRewrite(const DataflowGraph& dataflow_graph,
275 const std::vector<CandidatePartition>& candidates) {
276 std::vector<SubGraph> sub_graphs;
277 sub_graphs.reserve(candidates.size());
278 for (const auto& candidate : candidates) {
279 sub_graphs.emplace_back(candidate->sub_graph_);
280 }
281 return SubGraph::ParallelRewrite(dataflow_graph, sub_graphs);
282}
283
284/*static*/
285std::vector<CandidatePartition> CandidatePartition::MaxCoalesce(
286 const DataflowGraph& dataflow_graph, std::vector<CandidatePartition> candidates) {
287 VLOG(1) << "Running MaxCoalesce over " << candidates.size() << " candidates";
288 // This is an eager version of using the simple (kOpaque, kOpaque) combiner.
289
290 // Switch to set representation.
291 CandidateSet result_set(std::move(candidates));
292
293 // Until fixed point...
294 size_t num_rounds = 0;
295 while (result_set.PrepareForNextRound()) {
296 VLOG_CONTEXT << "round " << ++num_rounds;
297 VLOG(1) << "checking " << result_set.size() << " candidates (" << result_set.first_new_index()
298 << " existing)";
299 IndexSet removed_this_round(result_set.size()); // over candidate indexes!
300
301 // Build map from post-dfs indices to the indices of candidates with corresponding entry node.
302 // NOTE: the index set is over candidate indices not post-dfs indices!
303 std::vector<IndexSet> entry_map(dataflow_graph.size(), IndexSet(result_set.size()));
304 for (size_t i = 0; i < result_set.size(); ++i) {
305 CandidatePartition candidate = result_set.at(i);
306 for (PostDfsIndex entry_index : candidate->sub_graph_->entry_) {
307 entry_map[entry_index].Add(i);
308 }
309 }
310
311 for (size_t i = 0; i < result_set.size(); ++i) {
312 if (removed_this_round[i]) {
313 // Already merged.
314 continue;
315 }
316 CandidatePartition upstream = result_set.at(i);
317 // Narrow our search to just those candidates which could touch.
318 IndexSet possible_downstream(result_set.size()); // over candidate indexes!
319 for (PostDfsIndex output_index : upstream->sub_graph_->output_) {
320 possible_downstream = possible_downstream | entry_map[output_index];
321 }
322 for (size_t j : possible_downstream) {
323 if (removed_this_round[j]) {
324 // Already merged.
325 continue;
326 }
327 if (i == j) {
328 // Ignore self.
329 continue;
330 }
331 CandidatePartition downstream = result_set.at(j);
332 if (!upstream.AreTouching(dataflow_graph, downstream)) {
333 continue;
334 }
335 CandidatePartition new_candidate = upstream.DisjointUnion(dataflow_graph, downstream);
336 VLOG(2) << "Merging upstream candidate " << upstream->ToString()
337 << " and downstream candidate " << downstream->ToString() << " to yield "
338 << new_candidate->ToString();
339 result_set.Add(dataflow_graph, new_candidate);
340 result_set.Remove(upstream);
341 removed_this_round.Add(i);
342 result_set.Remove(downstream);
343 removed_this_round.Add(j);
344 }
345 }
346 }
347
348 // Restore canonical order.
349 result_set.sort();
350
351 VLOG(1) << "MaxCoalesce produced " << result_set.size() << " candidates";
352 return result_set.MovedCurrentCandidates();
353}
354
355} // namespace collage
356} // namespace relay
357} // namespace tvm
358