1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
20#define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
21
22#include <tvm/meta_schedule/arg_info.h>
23#include <tvm/meta_schedule/cost_model.h>
24#include <tvm/meta_schedule/database.h>
25#include <tvm/meta_schedule/measure_candidate.h>
26#include <tvm/meta_schedule/runner.h>
27#include <tvm/node/reflection.h>
28#include <tvm/runtime/container/array.h>
29#include <tvm/runtime/container/optional.h>
30#include <tvm/runtime/object.h>
31#include <tvm/runtime/packed_func.h>
32#include <tvm/tir/schedule/schedule.h>
33
34namespace tvm {
35namespace meta_schedule {
36
// Forward declarations: both names are referenced by SearchStrategyNode before they are defined.
38class TuneContext;
39class SearchStrategy;
40
/*!
 * \brief The search strategy for measure candidates generation.
 * \note The relationship between SearchStrategy and other classes is as follows:
      ┌──────────────────────────────────────────────────────────────┐
   ┌──┴───────────────────────────────────────────────────────────┐  │
┌──┴────────────────── Tune Context ───────────────────────────┐  │  │
│               ┌─────────────────────┐                        │  │  │
│               │                     │   Generate             │  │  │
│               │   Space Generator   ├──────────────┐         │  │  │
│               │                     │              │         │  │  │
│               └─────────────────────┘              ▼         │  │  │
│                                              Design Space    │  │  │
│               ┌─────────────────────┐              │         │  │  │
│     Generate  │                     │   Pretuning  │         │  │  │
│    ┌──────────┤   Search Strategy   │◄─────────────┘         │  │  │
│    │          │                     │                        │  ├──┘
│    │          └─────────────────────┘                        ├──┘
└────┼─────────────────────────────────────────────────────────┘
     │
     │
┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐
│    │                               ┌───────────┐                    │
│    │                   Send to     │           │  Send to           │
│    ▼                ┌─────────────►│  Builder  ├──────────┐         │
│ Measure Candidate   │  Builder     │           │  Runner  │         │
│    │                │              └───────────┘          │         │
│    │    ┌───────────┴─────────┐                           │         │
│    │    │                     │    ┌───────────┐          │         │
│    └───►│   Task Scheduler    │    │           │          │         │
│         │                     │    │  Runner   │◄─────────┘         │
│         └─────────────────────┘    │           │                    │
│                    ▲               └─────┬─────┘                    │
│                    │                     │                          │
│                    └─── Runner Future ◄──┘                          │
└─────────────────────────────────────────────────────────────────────┘
*/
77class SearchStrategyNode : public runtime::Object {
78 public:
79 /*! \brief Virtual destructor */
80 virtual ~SearchStrategyNode() = default;
81
82 /*!
83 * \brief Initialize the search strategy with tuning context.
84 * \param context The tuning context for initialization.
85 * \note This method is supposed to be called only once before every other method.
86 */
87 virtual void InitializeWithTuneContext(const TuneContext& context) = 0;
88
89 /*!
90 * \brief Pre-tuning for the search strategy.
91 * \param max_trials The maximum number of trials.
92 * \param num_trials_per_iter The number of trials per iteration.
93 * \param design_spaces The design spaces used during tuning process.
94 * \param database The database used during tuning process.
95 * \param cost_model The cost model used during tuning process.
96 * \note Pre-tuning is supposed to be called before the tuning process and after the
97 * initialization. Because the search strategy is stateful, we can always call pretuning
98 * and reset the search strategy.
99 */
100 virtual void PreTuning(int max_trials, int num_trials_per_iter,
101 const Array<tir::Schedule>& design_spaces,
102 const Optional<Database>& database,
103 const Optional<CostModel>& cost_model) = 0;
104
105 /*!
106 * \brief Post-tuning for the search strategy.
107 * \note Post-tuning is supposed to be called after the tuning process and before we reset the
108 * search strategy with another pre-tuning. Post-tuning can be empty.
109 */
110 virtual void PostTuning() = 0;
111
112 /*!
113 * \brief Generate measure candidates from design spaces for measurement.
114 * \return The measure candidates generated, nullptr if finished.
115 */
116 virtual Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() = 0;
117
118 /*!
119 * \brief Update the search strategy with measurement results.
120 * \param measure_candidates The candidates to be measured.
121 * \param results The measurement results from the runner.
122 */
123 virtual void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
124 const Array<RunnerResult>& results) = 0;
125
126 /*!
127 * \brief Clone the search strategy.
128 * \return The cloned search strategy.
129 */
130 virtual SearchStrategy Clone() const = 0;
131
132 static constexpr const char* _type_key = "meta_schedule.SearchStrategy";
133 TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object);
134};
135
136/*!
137 * \brief Managed reference to SearchStrategyNode.
138 * \sa SearchStrategyNode
139 */
140class SearchStrategy : public runtime::ObjectRef {
141 public:
142 /*!
143 * \brief The function type of `InitializeWithTuneContext` method.
144 * \param context The tuning context for initialization.
145 */
146 using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
147 /*!
148 * \brief The function type of `PreTuning` method.
149 */
150 using FPreTuning = runtime::TypedPackedFunc<void(
151 int max_trials, int num_trials_per_iter, const Array<tir::Schedule>&,
152 const Optional<Database>&, const Optional<CostModel>&)>;
153 /*! \brief The function type of `PostTuning` method. */
154 using FPostTuning = runtime::TypedPackedFunc<void()>;
155 /*!
156 * \brief The function type of `GenerateMeasureCandidates` method.
157 * \return The measure candidates generated, nullptr if finished.
158 */
159 using FGenerateMeasureCandidates = runtime::TypedPackedFunc<Optional<Array<MeasureCandidate>>()>;
160 /*!
161 * \brief The function type of `NotifyRunnerResults` method.
162 * \param results The measurement results from the runner.
163 */
164 using FNotifyRunnerResults =
165 runtime::TypedPackedFunc<void(const Array<MeasureCandidate>&, const Array<RunnerResult>&)>;
166 /*!
167 * \brief The function type of `Clone` method.
168 * \return The cloned search strategy.
169 */
170 using FClone = runtime::TypedPackedFunc<SearchStrategy()>;
171 /*!
172 * \brief Create a search strategy with customized methods on the python-side.
173 * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
174 * \param f_pre_tuning The packed function of `PreTuning`.
175 * \param f_post_tuning The packed function of `PostTuning`.
176 * \param f_generate_measure_candidates The packed function of `GenerateMeasureCandidates`.
177 * \param f_notify_runner_results The packed function of `NotifyRunnerResults`.
178 * \param f_clone The packed function of `Clone`.
179 * \return The search strategy created.
180 */
181 TVM_DLL static SearchStrategy PySearchStrategy(
182 FInitializeWithTuneContext f_initialize_with_tune_context, //
183 FPreTuning f_pre_tuning, //
184 FPostTuning f_post_tuning, //
185 FGenerateMeasureCandidates f_generate_measure_candidates, //
186 FNotifyRunnerResults f_notify_runner_results, //
187 FClone f_clone);
188
189 /*!
190 * \brief Constructor of replay trace search strategy.
191 * \param max_fail_count The max number of failures during trace replaying.
192 */
193 TVM_DLL static SearchStrategy ReplayTrace(int max_fail_count);
194
195 /*! \brief Constructor of replay func search strategy. */
196 TVM_DLL static SearchStrategy ReplayFunc();
197
198 /*!
199 * \brief Constructor of evolutionary search strategy.
200 * \param population_size The initial sample population.
201 * \param init_measured_ratio The ratio of measures samples in initial population.
202 * \param init_min_unmeasured The minimal size of unmeasured population in the initial sampling.
203 * \param max_fail_count The max number of failure during initial sampling.
204 * \param genetic_num_iters The iterations to run the genetic algorithm.
205 * \param genetic_mutate_prob The probability of mutation.
206 * \param genetic_max_fail_count The maximum number to try evolving the given trace.
207 * \param eps_greedy The ratio to select samples in a greedy fashion via their predicted score.
208 */
209 TVM_DLL static SearchStrategy EvolutionarySearch(int population_size, //
210 double init_measured_ratio, //
211 int init_min_unmeasured, //
212 int max_fail_count, //
213 int genetic_num_iters, //
214 double genetic_mutate_prob, //
215 int genetic_max_fail_count, //
216 double eps_greedy);
217
218 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode);
219};
220
221/*! \brief The python side customizable class for measure candidate generation */
222class PySearchStrategyNode : public SearchStrategyNode {
223 public:
224 using FInitializeWithTuneContext = SearchStrategy::FInitializeWithTuneContext;
225 using FPreTuning = SearchStrategy::FPreTuning;
226 using FPostTuning = SearchStrategy::FPostTuning;
227 using FGenerateMeasureCandidates = SearchStrategy::FGenerateMeasureCandidates;
228 using FNotifyRunnerResults = SearchStrategy::FNotifyRunnerResults;
229 using FClone = SearchStrategy::FClone;
230
231 /*! \brief The packed function to the `InitializeWithTuneContext` method. */
232 FInitializeWithTuneContext f_initialize_with_tune_context;
233 /*! \brief The packed function to the `PreTuning` method. */
234 FPreTuning f_pre_tuning;
235 /*! \brief The packed function to the `PostTuning` method. */
236 FPostTuning f_post_tuning;
237 /*! \brief The packed function to the `GenerateMeasureCandidates` method. */
238 FGenerateMeasureCandidates f_generate_measure_candidates;
239 /*! \brief The packed function to the `NotifyRunnerResults` method. */
240 FNotifyRunnerResults f_notify_runner_results;
241 /*! \brief The packed function to the `Clone` method. */
242 FClone f_clone;
243
244 void VisitAttrs(tvm::AttrVisitor* v) {
245 // `f_initialize_with_tune_context` is not visited
246 // `f_pre_tuning` is not visited
247 // `f_post_tuning` is not visited
248 // `f_generate_measure_candidates` is not visited
249 // `f_notify_runner_results` is not visited
250 // `f_clone` is not visited
251 }
252
253 void InitializeWithTuneContext(const TuneContext& context) final;
254 void PreTuning(int max_trials, int num_trials_per_iter, const Array<tir::Schedule>& design_spaces,
255 const Optional<Database>& database, const Optional<CostModel>& cost_model) final;
256 void PostTuning() final;
257 Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final;
258 void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
259 const Array<RunnerResult>& results);
260 SearchStrategy Clone() const final;
261
262 static constexpr const char* _type_key = "meta_schedule.PySearchStrategy";
263 TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode);
264};
265
266} // namespace meta_schedule
267} // namespace tvm
268
269#endif // TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
270