1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
 * \file src/relay/collage/candidate_partition.h
22 * \brief A potential partition in the Collage search.
23 */
24
25#ifndef TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_H_
26#define TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_H_
27
28#include <tvm/runtime/container/string.h>
29#include <tvm/target/compilation_config.h>
30
31#include <memory>
32#include <string>
33#include <vector>
34
35#include "./candidate_function_cache.h"
36#include "./cost.h"
37#include "./cost_estimator.h"
38#include "./name_supply.h"
39#include "./sub_graph.h"
40
41namespace tvm {
42namespace relay {
43namespace collage {
44
45class PartitionSpec;
46
47/*!
48 * \brief A candidate partition w.r.t. the overall Relay model.
49 *
50 * We represent the partition as a sub-graph. This means not only can we represent the scope
51 * of Relay sub-expressions intended for a particular partition (or kernel), but we can also
52 * represent various conventions for encoding how the operators within the partition should be
53 * tagged for downstream processing.
54 */
55class CandidatePartitionNode : public Object {
56 public:
57 CandidatePartitionNode() = default;
58
59 /*!
60 * \brief Combination of all the partition rule names which produced this candidate.
61 * For debugging and explainability.
62 */
63 String rule_name_;
64
65 /*!
66 * \brief The sub-graph of the overall expression matched by the partition rule.
67 */
68 SubGraph sub_graph_;
69
70 /*!
71 * \brief The partition specification which produced this candidate.
72 */
73 ObjectRef /* actually PartitionSpec */ spec_;
74
75 /*!
76 * \brief The (cached) cost of the partition.
77 *
78 * Initially Cost::Unknown, calculated and cached by EstimateCost.
79 */
80 mutable Cost cost_ = Cost::Unknown();
81
82 void VisitAttrs(AttrVisitor* v);
83
84 /*!
85 * \brief Returns the partition specification which produced this candidate.
86 */
87 PartitionSpec partition_spec() const;
88
89 /*!
90 * \brief Returns the name of the partition specification which produced this candidate.
91 */
92 std::string partition_spec_name() const;
93
94 /*!
95 * \brief Returns the target of the partition specification which produced this candidate.
96 */
97 Target target() const;
98
99 /*!
100 * \brief Return the estimated cost of the candidate partition, using \p cost_estimator and
101 * \p cache.
102 */
103 Cost EstimatedCost(const DataflowGraph& dataflow_graph, const CostEstimator& cost_estimator,
104 const std::shared_ptr<CandidateFunctionCache>& cache) const;
105
106 /*!
107 * \brief Returns a brief description of candidate suitable for debugging output.
108 */
109 std::string ToSummary(const DataflowGraph& dataflow_graph) const;
110
111 std::string ToString() const;
112
113 static constexpr const char* _type_key = "relay.collage.CandidatePartition";
114 TVM_DECLARE_FINAL_OBJECT_INFO(CandidatePartitionNode, Object);
115};
116
117class CandidatePartition : public ObjectRef {
118 public:
119 CandidatePartition(String rule_name, SubGraph sub_graph,
120 ObjectRef /* actually PartitionSpec */ spec, Cost cost = Cost::Unknown());
121
122 bool operator<(const CandidatePartition& that) const;
123
124 /*!
125 * \brief Returns true if this and \p that candidate are disjoint, have the same (or no) target,
126 * and touch. This does not imply the \p DisjointUnion of this and that will be valid. For
127 * example, the result may be too deep or have too many outputs.
128 */
129 bool AreTouching(const DataflowGraph& dataflow_graph, const CandidatePartition& that) const;
130
131 /*!
132 * \brief Returns the disjoint union of this and \p that.
133 */
134 CandidatePartition DisjointUnion(const DataflowGraph& dataflow_graph,
135 const CandidatePartition& that) const;
136
137 /*!
138 * \brief Returns the disjoint union of all \p candidates.
139 */
140 static CandidatePartition DisjointUnion(const DataflowGraph& dataflow_graph,
141 std::vector<CandidatePartition> candidates);
142
143 /*!
144 * \brief Returns the root expression of \p dataflow_graph rewritten to apply all the partitions
145 * implied by \p candidates. The candidates can be in any order but must be disjoint.
146 */
147 static Expr ParallelRewrite(const DataflowGraph& dataflow_graph,
148 const std::vector<CandidatePartition>& candidates);
149
150 /*!
151 * Eagerly merge all touching candidates for the same target. The candidates must be disjoint
152 * and have their Targets filled in. This is typically called on the optimal list of candidate
153 * partitions found by the Collage search in order to remove unnecessary partition boundaries.
154 * Ideally the search would never produce such candidates however to keep the search space
155 * manageable Collage may only consider candidate partitions up to a particular depth.
156 */
157 static std::vector<CandidatePartition> MaxCoalesce(const DataflowGraph& dataflow_graph,
158 std::vector<CandidatePartition> candidates);
159
160 TVM_DEFINE_OBJECT_REF_METHODS(CandidatePartition, ObjectRef, CandidatePartitionNode);
161 TVM_DEFINE_OBJECT_REF_COW_METHOD(CandidatePartitionNode);
162};
163
164CandidatePartition WithRuleName(CandidatePartition candidate, String rule_name);
165CandidatePartition WithTarget(CandidatePartition candidate, Target target);
166CandidatePartition WithSubGraph(CandidatePartition candidate, SubGraph sub_graph);
167
168struct CandidatePartitionHash {
169 size_t operator()(const CandidatePartition& candidate) const {
170 return candidate->sub_graph_->hash();
171 }
172};
173
174struct CandidatePartitionEquals {
175 bool operator()(const CandidatePartition& left, const CandidatePartition& right) const {
176 return *left->sub_graph_.get() == *right->sub_graph_.get();
177 }
178};
179
180struct CandidatePartitionCompare {
181 bool operator()(const CandidatePartition& left, const CandidatePartition& right) const {
182 return *left->sub_graph_.get() < *right->sub_graph_.get();
183 }
184};
185
186} // namespace collage
187} // namespace relay
188} // namespace tvm
189
190#endif // TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_H_
191