1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
 * \file relay/collage/candidate_partition_index.cc
22 * \brief Index for finding relevant candidate partitions for a particular search state.
23 */
24
#include "./candidate_partition_index.h"

#include <algorithm>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "./gather_partition_specs.h"
#include "./prune_candidates.h"
#include "./utils.h"
30
31namespace tvm {
32namespace relay {
33namespace collage {
34
/*!
 * \brief Constructs an empty index over \p dataflow_graph.
 *
 * One (initially empty) bucket of candidates is allocated per node of the dataflow graph; buckets
 * are later addressed by post-dfs index.
 *
 * \param virtual_devices Map from expression nodes to their existing virtual device assignment,
 * consulted when deciding whether a candidate is compatible. Stored as a raw pointer; this index
 * does not appear to take ownership, so presumably it must outlive the index (NOTE(review):
 * confirm with callers).
 * \param dataflow_graph The dataflow graph whose nodes candidates are drawn from. Also stored as a
 * raw pointer without taking ownership.
 */
CandidatePartitionIndex::CandidatePartitionIndex(
    const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices,
    DataflowGraph* dataflow_graph)
    : virtual_devices_(virtual_devices),
      dataflow_graph_(dataflow_graph),
      // One bucket per dataflow graph node, indexed by post-dfs index.
      first_inside_index_to_candidates_(dataflow_graph->size()) {}
41
42void CandidatePartitionIndex::Index(const Array<PartitionSpec>& partition_specs) {
43 std::vector<CandidatePartition> candidates = Collect(partition_specs);
44 candidates = PruneCandidates(*dataflow_graph_, candidates);
45 // Index the candidates by their first inside index.
46 for (auto& candidate : candidates) {
47 first_inside_index_to_candidates_[candidate->sub_graph_->first_inside_index_].emplace_back(
48 candidate);
49 }
50 size_ = candidates.size();
51}
52
53void CandidatePartitionIndex::EstimateAllCosts(
54 const CostEstimator cost_estimator, const std::shared_ptr<CandidateFunctionCache>& cache) {
55 size_t n = 0;
56 for (PostDfsIndex index = 0; index < dataflow_graph_->size(); ++index) {
57 for (const auto& candidate : first_inside_index_to_candidates_[index]) {
58 LOG(INFO) << "Estimating cost of candidate " << candidate->ToSummary(*dataflow_graph_) << " ["
59 << n++ << "/" << size_ << "]";
60 // Cost will be cached in candidate as a side effect.
61 Cost cost = candidate->EstimatedCost(*dataflow_graph_, cost_estimator, cache);
62 LOG(INFO) << "Candidate has cost " << cost.ToString();
63 }
64 }
65}
66
67std::string CandidatePartitionIndex::ToSummary() const {
68 std::vector<std::string> lines;
69 for (const auto& candidates : first_inside_index_to_candidates_) {
70 for (const auto& candidate : candidates) {
71 if (candidate->partition_spec_name() == kHostSpecName) {
72 continue;
73 }
74 lines.emplace_back(candidate->ToSummary(*dataflow_graph_));
75 }
76 }
77 std::sort(lines.begin(), lines.end());
78 std::ostringstream os;
79 bool first = true;
80 for (const auto& line : lines) {
81 if (first) {
82 first = false;
83 } else {
84 os << std::endl;
85 }
86 os << line;
87 }
88 return os.str();
89}
90
91bool CandidatePartitionIndex::IsCompatibleWithVirtualDevice(const CandidatePartition& candidate) {
92 for (PostDfsIndex index : candidate->sub_graph_->inside_) {
93 const ExprNode* sub_expr_node = dataflow_graph_->index_to_node(index)->node_ref_;
94 if (sub_expr_node->IsInstance<OpNode>() || sub_expr_node->IsInstance<ConstructorNode>()) {
95 // These nodes are target/device polymorphic.
96 continue;
97 }
98 auto itr = virtual_devices_->find(sub_expr_node);
99 ICHECK(itr != virtual_devices_->end()) << PrettyPrint(GetRef<Expr>(sub_expr_node));
100 const Target& existing_target = itr->second->target;
101 if (!existing_target.defined()) {
102 // No constraint.
103 continue;
104 }
105 if (StructuralEqual()(existing_target, candidate->target())) {
106 // No disagreement.
107 continue;
108 }
109 if (!candidate->target().IsExternalCodegenFor(itr->second->target)) {
110 // The candidate's target is not an external codegen target compatible with the existing
111 // target.
112 // TODO(mbs): There's a conflict here between Collage's desire to leave some expression nodes
113 // 'behind' on the VM and PlanDevice's desire to assign a primitive Target to every node.
114 // I think PlanDevices is the one that needs to give here by leaving such nodes
115 // unconstrained.
116 VLOG(1) << "Ignoring candidate " << candidate->ToString()
117 << " since incompatible with existing virtual device assignment of:" << std::endl
118 << itr->second << std::endl
119 << "to sub-graph:" << std::endl
120 << PrettyPrint(GetRef<Expr>(sub_expr_node));
121 return false;
122 }
123 }
124 return true;
125}
126
127std::vector<CandidatePartition> CandidatePartitionIndex::Collect(
128 const Array<PartitionSpec>& partition_specs) {
129 VLOG_CONTEXT << "collecting";
130 std::vector<CandidatePartition> result;
131 for (const auto& spec : partition_specs) {
132 VLOG_CONTEXT << "spec " << spec->spec_name_;
133 VLOG(1) << "collecting candidates";
134 std::vector<CandidatePartition> candidates = spec->AllCandidates(*dataflow_graph_);
135 for (auto& candidate : candidates) {
136 if (!IsCompatibleWithVirtualDevice(candidate)) {
137 continue;
138 }
139 result.push_back(candidate);
140 }
141 }
142 VLOG(1) << "Found " << result.size() << " candidates";
143 return result;
144}
145
146} // namespace collage
147} // namespace relay
148} // namespace tvm
149