1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file src/relay/collage/candidate_partition.h
 * \brief A potential partition in the Collage search.
 */
24 | |
25 | #ifndef TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_H_ |
26 | #define TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_H_ |
27 | |
28 | #include <tvm/runtime/container/string.h> |
29 | #include <tvm/target/compilation_config.h> |
30 | |
31 | #include <memory> |
32 | #include <string> |
33 | #include <vector> |
34 | |
35 | #include "./candidate_function_cache.h" |
36 | #include "./cost.h" |
37 | #include "./cost_estimator.h" |
38 | #include "./name_supply.h" |
39 | #include "./sub_graph.h" |
40 | |
41 | namespace tvm { |
42 | namespace relay { |
43 | namespace collage { |
44 | |
// Only forward-declared here: the node stores its spec as a plain ObjectRef (see spec_ below)
// and returns a PartitionSpec by value from an out-of-line member. NOTE(review): presumably this
// avoids a circular include with the partition spec header — confirm.
class PartitionSpec;

/*!
 * \brief A candidate partition w.r.t. the overall Relay model.
 *
 * We represent the partition as a sub-graph. This means not only can we represent the scope
 * of Relay sub-expressions intended for a particular partition (or kernel), but we can also
 * represent various conventions for encoding how the operators within the partition should be
 * tagged for downstream processing.
 */
class CandidatePartitionNode : public Object {
 public:
  /*! \brief Default constructor; every field starts at its in-class default. */
  CandidatePartitionNode() = default;

  /*!
   * \brief Combination of all the partition rule names which produced this candidate.
   * For debugging and explainability.
   */
  String rule_name_;

  /*!
   * \brief The sub-graph of the overall expression matched by the partition rule.
   */
  SubGraph sub_graph_;

  /*!
   * \brief The partition specification which produced this candidate.
   *
   * Declared as an ObjectRef, but always actually holds a PartitionSpec. Use partition_spec()
   * for typed access.
   */
  ObjectRef /* actually PartitionSpec */ spec_;

  /*!
   * \brief The (cached) cost of the partition.
   *
   * Initially Cost::Unknown, calculated and cached by EstimatedCost. Declared mutable so that
   * the const EstimatedCost member can fill in the cache.
   */
  mutable Cost cost_ = Cost::Unknown();

  /*!
   * \brief Visits this node's fields with \p v (TVM attribute reflection).
   * NOTE(review): defined out of line; confirm which fields are exposed.
   */
  void VisitAttrs(AttrVisitor* v);

  /*!
   * \brief Returns the partition specification which produced this candidate.
   * Presumably this is \p spec_ downcast to PartitionSpec.
   */
  PartitionSpec partition_spec() const;

  /*!
   * \brief Returns the name of the partition specification which produced this candidate.
   */
  std::string partition_spec_name() const;

  /*!
   * \brief Returns the target of the partition specification which produced this candidate.
   * The node has no target field of its own; the target comes from the spec.
   */
  Target target() const;

  /*!
   * \brief Return the estimated cost of the candidate partition, using \p cost_estimator and
   * \p cache.
   *
   * The result is cached in \p cost_, so later calls can reuse it.
   */
  Cost EstimatedCost(const DataflowGraph& dataflow_graph, const CostEstimator& cost_estimator,
                     const std::shared_ptr<CandidateFunctionCache>& cache) const;

  /*!
   * \brief Returns a brief description of candidate suitable for debugging output.
   */
  std::string ToSummary(const DataflowGraph& dataflow_graph) const;

  /*!
   * \brief Returns a string rendering of this candidate.
   * NOTE(review): defined out of line; unlike ToSummary it takes no DataflowGraph, so the
   * rendering is presumably based on the node's fields alone — confirm.
   */
  std::string ToString() const;

  static constexpr const char* _type_key = "relay.collage.CandidatePartition" ;
  TVM_DECLARE_FINAL_OBJECT_INFO(CandidatePartitionNode, Object);
};
116 | |
117 | class CandidatePartition : public ObjectRef { |
118 | public: |
119 | CandidatePartition(String rule_name, SubGraph sub_graph, |
120 | ObjectRef /* actually PartitionSpec */ spec, Cost cost = Cost::Unknown()); |
121 | |
122 | bool operator<(const CandidatePartition& that) const; |
123 | |
124 | /*! |
125 | * \brief Returns true if this and \p that candidate are disjoint, have the same (or no) target, |
126 | * and touch. This does not imply the \p DisjointUnion of this and that will be valid. For |
127 | * example, the result may be too deep or have too many outputs. |
128 | */ |
129 | bool AreTouching(const DataflowGraph& dataflow_graph, const CandidatePartition& that) const; |
130 | |
131 | /*! |
132 | * \brief Returns the disjoint union of this and \p that. |
133 | */ |
134 | CandidatePartition DisjointUnion(const DataflowGraph& dataflow_graph, |
135 | const CandidatePartition& that) const; |
136 | |
137 | /*! |
138 | * \brief Returns the disjoint union of all \p candidates. |
139 | */ |
140 | static CandidatePartition DisjointUnion(const DataflowGraph& dataflow_graph, |
141 | std::vector<CandidatePartition> candidates); |
142 | |
143 | /*! |
144 | * \brief Returns the root expression of \p dataflow_graph rewritten to apply all the partitions |
145 | * implied by \p candidates. The candidates can be in any order but must be disjoint. |
146 | */ |
147 | static Expr ParallelRewrite(const DataflowGraph& dataflow_graph, |
148 | const std::vector<CandidatePartition>& candidates); |
149 | |
150 | /*! |
151 | * Eagerly merge all touching candidates for the same target. The candidates must be disjoint |
152 | * and have their Targets filled in. This is typically called on the optimal list of candidate |
153 | * partitions found by the Collage search in order to remove unnecessary partition boundaries. |
154 | * Ideally the search would never produce such candidates however to keep the search space |
155 | * manageable Collage may only consider candidate partitions up to a particular depth. |
156 | */ |
157 | static std::vector<CandidatePartition> MaxCoalesce(const DataflowGraph& dataflow_graph, |
158 | std::vector<CandidatePartition> candidates); |
159 | |
160 | TVM_DEFINE_OBJECT_REF_METHODS(CandidatePartition, ObjectRef, CandidatePartitionNode); |
161 | TVM_DEFINE_OBJECT_REF_COW_METHOD(CandidatePartitionNode); |
162 | }; |
163 | |
// NOTE(review): the three helpers below are defined out of line. Judging by their names and
// their by-value \p candidate parameters, each presumably returns a copy of \p candidate with a
// single attribute replaced (possibly via CandidatePartition::CopyOnWrite) — confirm against
// the definitions.

/*! \brief Presumably returns \p candidate with its rule name replaced by \p rule_name. */
CandidatePartition WithRuleName(CandidatePartition candidate, String rule_name);
/*!
 * \brief Presumably returns \p candidate re-targeted to \p target.
 * CandidatePartitionNode has no target field; its target() comes from the partition spec.
 * This helper therefore likely swaps in a spec for \p target — confirm.
 */
CandidatePartition WithTarget(CandidatePartition candidate, Target target);
/*! \brief Presumably returns \p candidate with its sub-graph replaced by \p sub_graph. */
CandidatePartition WithSubGraph(CandidatePartition candidate, SubGraph sub_graph);
167 | |
168 | struct CandidatePartitionHash { |
169 | size_t operator()(const CandidatePartition& candidate) const { |
170 | return candidate->sub_graph_->hash(); |
171 | } |
172 | }; |
173 | |
174 | struct CandidatePartitionEquals { |
175 | bool operator()(const CandidatePartition& left, const CandidatePartition& right) const { |
176 | return *left->sub_graph_.get() == *right->sub_graph_.get(); |
177 | } |
178 | }; |
179 | |
180 | struct CandidatePartitionCompare { |
181 | bool operator()(const CandidatePartition& left, const CandidatePartition& right) const { |
182 | return *left->sub_graph_.get() < *right->sub_graph_.get(); |
183 | } |
184 | }; |
185 | |
186 | } // namespace collage |
187 | } // namespace relay |
188 | } // namespace tvm |
189 | |
190 | #endif // TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_H_ |
191 | |