1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
 * \file relay/collage/candidate_partition_index.cc
22 | * \brief Index for finding relevant candidate partitions for a particular search state. |
23 | */ |
24 | |
25 | #include "./candidate_partition_index.h" |
26 | |
27 | #include "./gather_partition_specs.h" |
28 | #include "./prune_candidates.h" |
29 | #include "./utils.h" |
30 | |
31 | namespace tvm { |
32 | namespace relay { |
33 | namespace collage { |
34 | |
35 | CandidatePartitionIndex::CandidatePartitionIndex( |
36 | const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices, |
37 | DataflowGraph* dataflow_graph) |
38 | : virtual_devices_(virtual_devices), |
39 | dataflow_graph_(dataflow_graph), |
40 | first_inside_index_to_candidates_(dataflow_graph->size()) {} |
41 | |
42 | void CandidatePartitionIndex::Index(const Array<PartitionSpec>& partition_specs) { |
43 | std::vector<CandidatePartition> candidates = Collect(partition_specs); |
44 | candidates = PruneCandidates(*dataflow_graph_, candidates); |
45 | // Index the candidates by their first inside index. |
46 | for (auto& candidate : candidates) { |
47 | first_inside_index_to_candidates_[candidate->sub_graph_->first_inside_index_].emplace_back( |
48 | candidate); |
49 | } |
50 | size_ = candidates.size(); |
51 | } |
52 | |
53 | void CandidatePartitionIndex::EstimateAllCosts( |
54 | const CostEstimator cost_estimator, const std::shared_ptr<CandidateFunctionCache>& cache) { |
55 | size_t n = 0; |
56 | for (PostDfsIndex index = 0; index < dataflow_graph_->size(); ++index) { |
57 | for (const auto& candidate : first_inside_index_to_candidates_[index]) { |
58 | LOG(INFO) << "Estimating cost of candidate " << candidate->ToSummary(*dataflow_graph_) << " [" |
59 | << n++ << "/" << size_ << "]" ; |
60 | // Cost will be cached in candidate as a side effect. |
61 | Cost cost = candidate->EstimatedCost(*dataflow_graph_, cost_estimator, cache); |
62 | LOG(INFO) << "Candidate has cost " << cost.ToString(); |
63 | } |
64 | } |
65 | } |
66 | |
67 | std::string CandidatePartitionIndex::ToSummary() const { |
68 | std::vector<std::string> lines; |
69 | for (const auto& candidates : first_inside_index_to_candidates_) { |
70 | for (const auto& candidate : candidates) { |
71 | if (candidate->partition_spec_name() == kHostSpecName) { |
72 | continue; |
73 | } |
74 | lines.emplace_back(candidate->ToSummary(*dataflow_graph_)); |
75 | } |
76 | } |
77 | std::sort(lines.begin(), lines.end()); |
78 | std::ostringstream os; |
79 | bool first = true; |
80 | for (const auto& line : lines) { |
81 | if (first) { |
82 | first = false; |
83 | } else { |
84 | os << std::endl; |
85 | } |
86 | os << line; |
87 | } |
88 | return os.str(); |
89 | } |
90 | |
91 | bool CandidatePartitionIndex::IsCompatibleWithVirtualDevice(const CandidatePartition& candidate) { |
92 | for (PostDfsIndex index : candidate->sub_graph_->inside_) { |
93 | const ExprNode* sub_expr_node = dataflow_graph_->index_to_node(index)->node_ref_; |
94 | if (sub_expr_node->IsInstance<OpNode>() || sub_expr_node->IsInstance<ConstructorNode>()) { |
95 | // These nodes are target/device polymorphic. |
96 | continue; |
97 | } |
98 | auto itr = virtual_devices_->find(sub_expr_node); |
99 | ICHECK(itr != virtual_devices_->end()) << PrettyPrint(GetRef<Expr>(sub_expr_node)); |
100 | const Target& existing_target = itr->second->target; |
101 | if (!existing_target.defined()) { |
102 | // No constraint. |
103 | continue; |
104 | } |
105 | if (StructuralEqual()(existing_target, candidate->target())) { |
106 | // No disagreement. |
107 | continue; |
108 | } |
109 | if (!candidate->target().IsExternalCodegenFor(itr->second->target)) { |
110 | // The candidate's target is not an external codegen target compatible with the existing |
111 | // target. |
112 | // TODO(mbs): There's a conflict here between Collage's desire to leave some expression nodes |
113 | // 'behind' on the VM and PlanDevice's desire to assign a primitive Target to every node. |
114 | // I think PlanDevices is the one that needs to give here by leaving such nodes |
115 | // unconstrained. |
116 | VLOG(1) << "Ignoring candidate " << candidate->ToString() |
117 | << " since incompatible with existing virtual device assignment of:" << std::endl |
118 | << itr->second << std::endl |
119 | << "to sub-graph:" << std::endl |
120 | << PrettyPrint(GetRef<Expr>(sub_expr_node)); |
121 | return false; |
122 | } |
123 | } |
124 | return true; |
125 | } |
126 | |
127 | std::vector<CandidatePartition> CandidatePartitionIndex::Collect( |
128 | const Array<PartitionSpec>& partition_specs) { |
129 | VLOG_CONTEXT << "collecting" ; |
130 | std::vector<CandidatePartition> result; |
131 | for (const auto& spec : partition_specs) { |
132 | VLOG_CONTEXT << "spec " << spec->spec_name_; |
133 | VLOG(1) << "collecting candidates" ; |
134 | std::vector<CandidatePartition> candidates = spec->AllCandidates(*dataflow_graph_); |
135 | for (auto& candidate : candidates) { |
136 | if (!IsCompatibleWithVirtualDevice(candidate)) { |
137 | continue; |
138 | } |
139 | result.push_back(candidate); |
140 | } |
141 | } |
142 | VLOG(1) << "Found " << result.size() << " candidates" ; |
143 | return result; |
144 | } |
145 | |
146 | } // namespace collage |
147 | } // namespace relay |
148 | } // namespace tvm |
149 | |