1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ |
20 | #define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ |
21 | |
22 | #include <tvm/meta_schedule/arg_info.h> |
23 | #include <tvm/meta_schedule/cost_model.h> |
24 | #include <tvm/meta_schedule/database.h> |
25 | #include <tvm/meta_schedule/measure_candidate.h> |
26 | #include <tvm/meta_schedule/runner.h> |
27 | #include <tvm/node/reflection.h> |
28 | #include <tvm/runtime/container/array.h> |
29 | #include <tvm/runtime/container/optional.h> |
30 | #include <tvm/runtime/object.h> |
31 | #include <tvm/runtime/packed_func.h> |
32 | #include <tvm/tir/schedule/schedule.h> |
33 | |
34 | namespace tvm { |
35 | namespace meta_schedule { |
36 | |
// Forward declarations. SearchStrategy is needed because SearchStrategyNode::Clone
// returns it by value before the class is defined; TuneContext is only taken by reference.
class SearchStrategy;
class TuneContext;
40 | |
41 | /*! |
42 | * \brief The search strategy for measure candidates generation. |
 * \note The relationship between SearchStrategy and other classes is as follows:
44 | ┌──────────────────────────────────────────────────────────────┐ |
45 | ┌──┴───────────────────────────────────────────────────────────┐ │ |
46 | ┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ |
47 | │ ┌─────────────────────┐ │ │ │ |
48 | │ │ │ Generate │ │ │ |
49 | │ │ Space Generator ├──────────────┐ │ │ │ |
50 | │ │ │ │ │ │ │ |
51 | │ └─────────────────────┘ ▼ │ │ │ |
52 | │ Design Space │ │ │ |
53 | │ ┌─────────────────────┐ │ │ │ │ |
54 | │ Generate │ │ Pretuning │ │ │ │ |
55 | │ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │ |
56 | │ │ │ │ │ ├──┘ |
57 | │ │ └─────────────────────┘ ├──┘ |
58 | └────┼─────────────────────────────────────────────────────────┘ |
59 | │ |
60 | │ |
61 | ┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐ |
62 | │ │ ┌───────────┐ │ |
63 | │ │ Send to │ │ Send to │ |
64 | │ ▼ ┌─────────────►│ Builder ├──────────┐ │ |
65 | │ Measure Candidate │ Builder │ │ Runner │ │ |
66 | │ │ │ └───────────┘ │ │ |
67 | │ │ ┌────────────┴────────┐ │ │ |
68 | │ │ │ │ ┌───────────┐ │ │ |
69 | │ └────►│ Task Scheduler │ │ │ │ │ |
70 | │ │ │ │ Runner │◄─────────┘ │ |
71 | │ └─────────────────────┘ │ │ │ |
72 | │ ▲ └─────┬─────┘ │ |
73 | │ │ │ │ |
74 | │ └─── Runner Future ◄────┘ │ |
75 | └─────────────────────────────────────────────────────────────────────┘ |
76 | */ |
77 | class SearchStrategyNode : public runtime::Object { |
78 | public: |
79 | /*! \brief Virtual destructor */ |
80 | virtual ~SearchStrategyNode() = default; |
81 | |
82 | /*! |
83 | * \brief Initialize the search strategy with tuning context. |
84 | * \param context The tuning context for initialization. |
85 | * \note This method is supposed to be called only once before every other method. |
86 | */ |
87 | virtual void InitializeWithTuneContext(const TuneContext& context) = 0; |
88 | |
89 | /*! |
90 | * \brief Pre-tuning for the search strategy. |
91 | * \param max_trials The maximum number of trials. |
92 | * \param num_trials_per_iter The number of trials per iteration. |
93 | * \param design_spaces The design spaces used during tuning process. |
94 | * \param database The database used during tuning process. |
95 | * \param cost_model The cost model used during tuning process. |
96 | * \note Pre-tuning is supposed to be called before the tuning process and after the |
97 | * initialization. Because the search strategy is stateful, we can always call pretuning |
98 | * and reset the search strategy. |
99 | */ |
100 | virtual void PreTuning(int max_trials, int num_trials_per_iter, |
101 | const Array<tir::Schedule>& design_spaces, |
102 | const Optional<Database>& database, |
103 | const Optional<CostModel>& cost_model) = 0; |
104 | |
105 | /*! |
106 | * \brief Post-tuning for the search strategy. |
107 | * \note Post-tuning is supposed to be called after the tuning process and before we reset the |
108 | * search strategy with another pre-tuning. Post-tuning can be empty. |
109 | */ |
110 | virtual void PostTuning() = 0; |
111 | |
112 | /*! |
113 | * \brief Generate measure candidates from design spaces for measurement. |
114 | * \return The measure candidates generated, nullptr if finished. |
115 | */ |
116 | virtual Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() = 0; |
117 | |
118 | /*! |
119 | * \brief Update the search strategy with measurement results. |
120 | * \param measure_candidates The candidates to be measured. |
121 | * \param results The measurement results from the runner. |
122 | */ |
123 | virtual void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates, |
124 | const Array<RunnerResult>& results) = 0; |
125 | |
126 | /*! |
127 | * \brief Clone the search strategy. |
128 | * \return The cloned search strategy. |
129 | */ |
130 | virtual SearchStrategy Clone() const = 0; |
131 | |
132 | static constexpr const char* _type_key = "meta_schedule.SearchStrategy" ; |
133 | TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object); |
134 | }; |
135 | |
136 | /*! |
137 | * \brief Managed reference to SearchStrategyNode. |
138 | * \sa SearchStrategyNode |
139 | */ |
140 | class SearchStrategy : public runtime::ObjectRef { |
141 | public: |
142 | /*! |
143 | * \brief The function type of `InitializeWithTuneContext` method. |
144 | * \param context The tuning context for initialization. |
145 | */ |
146 | using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>; |
147 | /*! |
148 | * \brief The function type of `PreTuning` method. |
149 | */ |
150 | using FPreTuning = runtime::TypedPackedFunc<void( |
151 | int max_trials, int num_trials_per_iter, const Array<tir::Schedule>&, |
152 | const Optional<Database>&, const Optional<CostModel>&)>; |
153 | /*! \brief The function type of `PostTuning` method. */ |
154 | using FPostTuning = runtime::TypedPackedFunc<void()>; |
155 | /*! |
156 | * \brief The function type of `GenerateMeasureCandidates` method. |
157 | * \return The measure candidates generated, nullptr if finished. |
158 | */ |
159 | using FGenerateMeasureCandidates = runtime::TypedPackedFunc<Optional<Array<MeasureCandidate>>()>; |
160 | /*! |
161 | * \brief The function type of `NotifyRunnerResults` method. |
162 | * \param results The measurement results from the runner. |
163 | */ |
164 | using FNotifyRunnerResults = |
165 | runtime::TypedPackedFunc<void(const Array<MeasureCandidate>&, const Array<RunnerResult>&)>; |
166 | /*! |
167 | * \brief The function type of `Clone` method. |
168 | * \return The cloned search strategy. |
169 | */ |
170 | using FClone = runtime::TypedPackedFunc<SearchStrategy()>; |
171 | /*! |
172 | * \brief Create a search strategy with customized methods on the python-side. |
173 | * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. |
174 | * \param f_pre_tuning The packed function of `PreTuning`. |
175 | * \param f_post_tuning The packed function of `PostTuning`. |
176 | * \param f_generate_measure_candidates The packed function of `GenerateMeasureCandidates`. |
177 | * \param f_notify_runner_results The packed function of `NotifyRunnerResults`. |
178 | * \param f_clone The packed function of `Clone`. |
179 | * \return The search strategy created. |
180 | */ |
181 | TVM_DLL static SearchStrategy PySearchStrategy( |
182 | FInitializeWithTuneContext f_initialize_with_tune_context, // |
183 | FPreTuning f_pre_tuning, // |
184 | FPostTuning f_post_tuning, // |
185 | FGenerateMeasureCandidates f_generate_measure_candidates, // |
186 | FNotifyRunnerResults f_notify_runner_results, // |
187 | FClone f_clone); |
188 | |
189 | /*! |
190 | * \brief Constructor of replay trace search strategy. |
191 | * \param max_fail_count The max number of failures during trace replaying. |
192 | */ |
193 | TVM_DLL static SearchStrategy ReplayTrace(int max_fail_count); |
194 | |
195 | /*! \brief Constructor of replay func search strategy. */ |
196 | TVM_DLL static SearchStrategy ReplayFunc(); |
197 | |
198 | /*! |
199 | * \brief Constructor of evolutionary search strategy. |
200 | * \param population_size The initial sample population. |
201 | * \param init_measured_ratio The ratio of measures samples in initial population. |
202 | * \param init_min_unmeasured The minimal size of unmeasured population in the initial sampling. |
203 | * \param max_fail_count The max number of failure during initial sampling. |
204 | * \param genetic_num_iters The iterations to run the genetic algorithm. |
205 | * \param genetic_mutate_prob The probability of mutation. |
206 | * \param genetic_max_fail_count The maximum number to try evolving the given trace. |
207 | * \param eps_greedy The ratio to select samples in a greedy fashion via their predicted score. |
208 | */ |
209 | TVM_DLL static SearchStrategy EvolutionarySearch(int population_size, // |
210 | double init_measured_ratio, // |
211 | int init_min_unmeasured, // |
212 | int max_fail_count, // |
213 | int genetic_num_iters, // |
214 | double genetic_mutate_prob, // |
215 | int genetic_max_fail_count, // |
216 | double eps_greedy); |
217 | |
218 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode); |
219 | }; |
220 | |
221 | /*! \brief The python side customizable class for measure candidate generation */ |
222 | class PySearchStrategyNode : public SearchStrategyNode { |
223 | public: |
224 | using FInitializeWithTuneContext = SearchStrategy::FInitializeWithTuneContext; |
225 | using FPreTuning = SearchStrategy::FPreTuning; |
226 | using FPostTuning = SearchStrategy::FPostTuning; |
227 | using FGenerateMeasureCandidates = SearchStrategy::FGenerateMeasureCandidates; |
228 | using FNotifyRunnerResults = SearchStrategy::FNotifyRunnerResults; |
229 | using FClone = SearchStrategy::FClone; |
230 | |
231 | /*! \brief The packed function to the `InitializeWithTuneContext` method. */ |
232 | FInitializeWithTuneContext f_initialize_with_tune_context; |
233 | /*! \brief The packed function to the `PreTuning` method. */ |
234 | FPreTuning f_pre_tuning; |
235 | /*! \brief The packed function to the `PostTuning` method. */ |
236 | FPostTuning f_post_tuning; |
237 | /*! \brief The packed function to the `GenerateMeasureCandidates` method. */ |
238 | FGenerateMeasureCandidates f_generate_measure_candidates; |
239 | /*! \brief The packed function to the `NotifyRunnerResults` method. */ |
240 | FNotifyRunnerResults f_notify_runner_results; |
241 | /*! \brief The packed function to the `Clone` method. */ |
242 | FClone f_clone; |
243 | |
244 | void VisitAttrs(tvm::AttrVisitor* v) { |
245 | // `f_initialize_with_tune_context` is not visited |
246 | // `f_pre_tuning` is not visited |
247 | // `f_post_tuning` is not visited |
248 | // `f_generate_measure_candidates` is not visited |
249 | // `f_notify_runner_results` is not visited |
250 | // `f_clone` is not visited |
251 | } |
252 | |
253 | void InitializeWithTuneContext(const TuneContext& context) final; |
254 | void PreTuning(int max_trials, int num_trials_per_iter, const Array<tir::Schedule>& design_spaces, |
255 | const Optional<Database>& database, const Optional<CostModel>& cost_model) final; |
256 | void PostTuning() final; |
257 | Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final; |
258 | void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates, |
259 | const Array<RunnerResult>& results); |
260 | SearchStrategy Clone() const final; |
261 | |
262 | static constexpr const char* _type_key = "meta_schedule.PySearchStrategy" ; |
263 | TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode); |
264 | }; |
265 | |
266 | } // namespace meta_schedule |
267 | } // namespace tvm |
268 | |
269 | #endif // TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ |
270 | |