1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#include "../module_equality.h"
21#include "../utils.h"
22
/*!
 * \brief Check that a probability-like argument lies within the closed interval [0, 1].
 * \param p The expression to validate. It is evaluated more than once, so it should be
 *  free of side effects.
 * \param name A human-readable name of the argument; it is streamed into the error message.
 * \note Expands to a single `CHECK(...) << ...` expression without a trailing semicolon;
 *  the caller writes the semicolon.
 */
#define TVM_META_SCHEDULE_CHECK_PROB_RANGE(p, name)               \
  CHECK(0.0 <= (p) && (p) <= 1.0)                                 \
      << "ValueError: " << (name) << " should be within [0, 1], " \
      << "but get `" << #p << " = " << (p) << '`'
26
27namespace tvm {
28namespace meta_schedule {
29
30using tir::Schedule;
31
32/**************** Data Structure ****************/
33
34/*! \brief An auxiliary data structure to help deduplicate IRModules */
35class IRModuleSet {
36 public:
37 explicit IRModuleSet(const ModuleEquality& mod_eq)
38 : tab_(/*bucket_count*/ 0, ItemHash(), ItemEqual(mod_eq)) {}
39
40 /*! \brief Add an IRModule to the set */
41 void Add(const IRModule& mod, size_t shash) { tab_.insert(Item{mod, shash}); }
42 /*! \brief Check if the IRModule is in the set */
43 bool Has(const IRModule& mod, size_t shash) const { return tab_.count(Item{mod, shash}); }
44
45 private:
46 struct Item {
47 IRModule mod;
48 size_t shash;
49 };
50 struct ItemHash {
51 size_t operator()(const Item& hash) const { return hash.shash; }
52 };
53 struct ItemEqual {
54 explicit ItemEqual(const ModuleEquality& mod_eq) : mod_eq_(mod_eq) {}
55 ItemEqual& operator=(const ItemEqual& other) { return *this; }
56
57 bool operator()(const Item& lhs, const Item& rhs) const {
58 return lhs.shash == rhs.shash && mod_eq_.Equal(lhs.mod, rhs.mod);
59 }
60
61 const ModuleEquality& mod_eq_;
62 };
63
64 std::unordered_set<Item, ItemHash, ItemEqual> tab_;
65};
66
67/*!
68 * \brief A heap with a size up-limit. If overflow happens, it evicted the worst items.
69 * \note It maintains a min heap in terms of `Item::score`. Therefore, when
70 * overflow happens, the element evicted is the one with the min `Item::score`.
71 * As time goes, the elements in the heap are going to be larger.
72 */
73class SizedHeap {
74 public:
75 struct Item {
76 Schedule sch;
77 double score;
78 bool operator<(const Item& other) const { return score > other.score; }
79 };
80
81 /*!
82 * \brief Constructor
83 * \param size_limit The up-limit of the heap size
84 */
85 explicit SizedHeap(int size_limit) : size_limit(size_limit) { heap.reserve(size_limit); }
86
87 /*!
88 * \brief Push the specific item to the heap if its key did not appears in the heap
89 * \param item The item to be pushed
90 */
91 void Push(Schedule sch, double score) {
92 int size = heap.size();
93 if (size < size_limit) {
94 // Heap is not full, just push
95 heap.emplace_back(Item{sch, score});
96 std::push_heap(heap.begin(), heap.end());
97 } else if (score > heap.front().score) {
98 // if the item is better than the worst one in the heap, we can safely kick it out
99 std::pop_heap(heap.begin(), heap.end());
100 heap.back() = {sch, score};
101 std::push_heap(heap.begin(), heap.end());
102 }
103 // Otherwise, the item is worse than any other element in the heap
104 }
105
106 /*! \brief Up-limit of the heap size */
107 int size_limit;
108 /*! \brief The heap, the worse the topper */
109 std::vector<Item> heap;
110};
111
112struct PerThreadData {
113 IRModule mod{nullptr};
114 TRandState rand_state{-1};
115 std::function<int32_t()> trace_sampler = nullptr;
116 std::function<Optional<Mutator>()> mutator_sampler = nullptr;
117
118 /*!
119 * \brief Set the value for the trace and mutator samplers per thread.
120 * \param scores The predicted score for the given samples.
121 * \param genetic_mutate_prob The probability of mutation.
122 * \param mutator_probs The probability of each mutator as a dict.
123 */
124 void Set(const std::vector<double>& scores, double genetic_mutate_prob,
125 const Map<Mutator, FloatImm>& mutator_probs) {
126 trace_sampler = tir::MakeMultinomialSampler(&rand_state, scores);
127 mutator_sampler = MakeMutatorSampler(genetic_mutate_prob, mutator_probs, &rand_state);
128 }
129
130 private:
131 /*!
132 * \brief Create a sampler function that picks mutators according to the mass function
133 * \param rand_state The random state for sampling
134 * \return The sampler created
135 */
136 static std::function<Optional<Mutator>()> MakeMutatorSampler(
137 double genetic_mutate_prob, //
138 const Map<Mutator, FloatImm>& mutator_probs, //
139 TRandState* rand_state) {
140 std::vector<Optional<Mutator>> mutators;
141 std::vector<double> masses;
142 mutators.push_back(NullOpt);
143 masses.push_back(1.0 - genetic_mutate_prob);
144 double total_mass_mutator = 0.0;
145 if (genetic_mutate_prob > 0) {
146 for (const auto& kv : mutator_probs) {
147 Mutator mutator = kv.first;
148 double mass = kv.second->value;
149 total_mass_mutator += mass;
150 mutators.push_back(mutator);
151 masses.push_back(mass * genetic_mutate_prob);
152 }
153 }
154 // Normalize the sum to 1.0
155 if (total_mass_mutator == 0.0) {
156 masses[0] = 1.0;
157 for (int i = 1, n = masses.size(); i < n; ++i) {
158 masses[i] = 0.0;
159 }
160 } else if (total_mass_mutator != 1.0) {
161 for (int i = 1, n = masses.size(); i < n; ++i) {
162 masses[i] /= total_mass_mutator;
163 }
164 }
165 return [idx_sampler = tir::MakeMultinomialSampler(rand_state, masses),
166 mutators = std::move(mutators)]() -> Optional<Mutator> {
167 int i = idx_sampler();
168 return mutators[i];
169 };
170 }
171};
172
/*! \brief A bitmask whose 64-bit words are each guarded by their own mutex. */
struct ConcurrentBitmask {
  /*! The bit width. */
  static constexpr const int kBitWidth = 64;
  /*! \brief The number of 64-bit words (not the number of slots). */
  int size;
  /*! \brief The bitmasks. */
  std::vector<uint64_t> bitmask;
  /*! \brief The mutexes, one per word of kBitWidth (64 here) bits. */
  std::vector<std::mutex> mutexes;

  /*!
   * \brief Constructor
   * \param n The total slots managed by the concurrent bitmask.
   */
  explicit ConcurrentBitmask(int n)
      : size((n + kBitWidth - 1) / kBitWidth), bitmask(size, 0), mutexes(size) {}

  /*!
   * \brief Mark the given index as used if it has not been marked before.
   * \param x The index to check and mark. No bounds check is performed.
   * \return True if the index was not marked before this call (it is marked now);
   *  false if it had already been marked.
   */
  bool QueryAndMark(int x) {
    const int word = x / kBitWidth;
    const uint64_t bit = static_cast<uint64_t>(1) << (x % kBitWidth);
    std::lock_guard<std::mutex> guard(mutexes[word]);
    uint64_t& slot = bitmask[word];
    const bool newly_marked = (slot & bit) == 0;
    // Setting the bit again is a no-op when it was already set
    slot |= bit;
    return newly_marked;
  }
};
205
206/**************** Util Functions ****************/
207
208/*!
209 * \brief Assemble measure candidates from the given candidate traces.
210 * \param traces The picked candidate traces.
211 * \return The assembled measure candidates.
212 */
213Array<MeasureCandidate> AssembleCandidates(const std::vector<Schedule>& picks) {
214 Array<MeasureCandidate> measure_inputs;
215 measure_inputs.reserve(picks.size());
216 for (const Schedule& sch : picks) {
217 measure_inputs.push_back(
218 MeasureCandidate(sch, ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true)));
219 }
220 return measure_inputs;
221}
222
223/*!
224 * \brief Predict the normalized score of each candidate.
225 * \param candidates The candidates for prediction
226 * \param task The search task
227 * \param space The search space
228 * \return The normalized score in the prediction
229 */
230std::vector<double> PredictNormalizedScore(const std::vector<Schedule>& candidates,
231 const TuneContext& context,
232 const CostModel& cost_model) {
233 auto _ = Profiler::TimedScope("EvoSearch/Evolve/PredictNormalizedScore");
234 ICHECK(!candidates.empty()) << "Candidates given for score prediction can not be empty list!";
235 std::vector<double> scores = cost_model->Predict(context, AssembleCandidates(candidates));
236 for (double& score : scores) {
237 score = std::max(0.0, score);
238 }
239 return scores;
240}
241
242/**************** Evolutionary Search ****************/
243
244/*!\brief A search strategy that generates measure candidates using evolutionary search. */
245class EvolutionarySearchNode : public SearchStrategyNode {
246 public:
247 /*! \brief The state of the search strategy. */
248 struct State {
249 /*! \brief The search strategy itself */
250 EvolutionarySearchNode* self;
251 /*! \brief The number of total trials. */
252 int max_trials;
253 /*! \brief The number of trials per iteration. */
254 int num_trials_per_iter;
255 /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
256 int st;
257 /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
258 int ed;
259 /*! \brief The counter of returning empty results. */
260 int num_empty_iters;
261 /*! \brief The design spaces. Decisions are not used so traces only. */
262 Array<tir::Trace> design_spaces;
263 /*! \brief Pre thread data including module to be tuned and random state. */
264 std::vector<PerThreadData> per_thread_data_;
265 /*!
266 * \brief The workloads that are already measured.
267 * TODO(junrushao1994): add records from the database to avoid re-measuring.
268 * */
269 IRModuleSet measured_workloads_;
270 /*! \brief A Database for selecting useful candidates. */
271 Database database_{nullptr};
272 /*! \brief A cost model helping to explore the search space */
273 CostModel cost_model_{nullptr};
274 /*! \brief The token registered for the given workload in database. */
275 Workload token_{nullptr};
276
277 explicit State(EvolutionarySearchNode* self, int max_trials, int num_trials_per_iter,
278 Array<Schedule> design_space_schedules, Database database, CostModel cost_model)
279 : self(self),
280 max_trials(max_trials),
281 num_trials_per_iter(num_trials_per_iter),
282 st(0),
283 ed(num_trials_per_iter),
284 num_empty_iters(0),
285 measured_workloads_(database->GetModuleEquality()) {
286 design_spaces.reserve(design_spaces.size());
287 for (const Schedule& space : design_space_schedules) {
288 design_spaces.push_back(space->trace().value()->Simplified(true));
289 }
290 const TuneContextNode* ctx = self->ctx_;
291 IRModule mod = ctx->mod.value();
292 this->per_thread_data_.resize(ctx->num_threads);
293 for (PerThreadData& data : this->per_thread_data_) {
294 data.mod = DeepCopyIRModule(mod);
295 data.rand_state = ForkSeed(&self->rand_state_);
296 }
297 this->database_ = database;
298 this->cost_model_ = cost_model;
299 this->token_ = database->CommitWorkload(mod);
300 }
301
302 /*!
303 * \brief Pick up best candidates from database.
304 * \param num The number of traces to produce.
305 * \return The picked best candidates.
306 */
307 inline std::vector<Schedule> PickBestFromDatabase(int num);
308 /*!
309 * \brief Sample the initial population from previous measured results and randomly generated
310 * traces via trace replaying.
311 * \param num The number of traces to produce.
312 * \return The initial population of traces sampled.
313 */
314 inline std::vector<Schedule> SampleInitPopulation(int num);
315 /*!
316 * \brief Evolve the initial population using mutators and samplers.
317 * \param population The initial population of traces sampled.
318 * \param num The number of traces to produce.
319 * \return The evolved traces from initial population.
320 */
321 inline std::vector<Schedule> EvolveWithCostModel(std::vector<Schedule> population, int num);
322 /*!
323 * \brief Pick final candidates from the given initial population and bests of evolved ones.
324 * \param inits The initial population of traces sampled.
325 * \param bests The best candidates predicted from evolved traces.
326 * \param num The number of traces to produce.
327 * \return The final picked candidates with a ratio of both.
328 */
329 inline std::vector<Schedule> PickWithEpsGreedy(const std::vector<Schedule>& inits,
330 const std::vector<Schedule>& bests, int num);
331 /*! \brief An interface method to be called by it's counterpart in EvolutionarySearchNode */
332 inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
333 /*! \brief An interface method to be called by it's counterpart in EvolutionarySearchNode */
334 inline void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
335 const Array<RunnerResult>& results);
336 /*!
337 * \brief Compute the hash for the given module.
338 * \param mod The input TIR module.
339 * \return The calculated hash.
340 */
341 inline size_t ModuleHash(const IRModule& mod) const;
342 };
343
344 /*! \brief The tuning context of the evolutionary search strategy. */
345 const TuneContextNode* ctx_{nullptr};
346 /*! \brief The postprocessors */
347 Array<Postproc> postprocs_;
348 /*! \brief The mutators and their probability. */
349 Map<Mutator, FloatImm> mutator_probs_;
350 /*! \brief The random state. To be initialized with TuneContext. */
351 TRandState rand_state_;
352 /*! \brief The state of the search strategy. */
353 std::unique_ptr<State> state_ = nullptr;
354
355 /*** Configuration: global ***/
356 /*! \brief The population size in the evolutionary search. */
357 int population_size;
358 /*!
359 * \brief The maximum number of iterations before early stopping to confirm the search space is
360 * exhausted
361 */
362 int num_empty_iters_before_early_stop;
363 /*** Configuration: the initial population ***/
364 /*! \brief The ratio of measured states used in the initial population */
365 double init_measured_ratio;
366 /*! \brief The minimal size of unmeasured population in the initial sampling.*/
367 int init_min_unmeasured;
368 /*! \brief The maximum number of failure during initial sampling. */
369 int max_fail_count;
370 /*** Configuration: evolution ***/
371 /*! \brief The number of iterations performed by generic algorithm. */
372 int genetic_num_iters;
373 /*! \brief The probability to perform mutation */
374 double genetic_mutate_prob;
375 /*! \brief The maximum number to try evolving the given trace. */
376 int genetic_max_fail_count;
377 /*** Configuration: pick states for measurement ***/
378 /*! \brief The ratio of measurements to use randomly sampled states. */
379 double eps_greedy;
380
381 void VisitAttrs(tvm::AttrVisitor* v) {
382 // `context_` is not visited
383 // `rand_state_` is not visited
384 // `state_` is not visited
385
386 /*** Configuration: global ***/
387 v->Visit("population_size", &population_size);
388 v->Visit("num_empty_iters_before_early_stop", &num_empty_iters_before_early_stop);
389 /*** Configuration: the initial population ***/
390 v->Visit("init_measured_ratio", &init_measured_ratio);
391 v->Visit("init_min_unmeasured", &init_min_unmeasured);
392 v->Visit("max_fail_count", &max_fail_count);
393 /*** Configuration: evolution ***/
394 v->Visit("genetic_num_iters", &genetic_num_iters);
395 v->Visit("genetic_mutate_prob", &genetic_mutate_prob);
396 v->Visit("genetic_max_fail_count", &genetic_max_fail_count);
397 /*** Configuration: pick states for measurement ***/
398 v->Visit("eps_greedy", &eps_greedy);
399 }
400
401 static constexpr const char* _type_key = "meta_schedule.EvolutionarySearch";
402 TVM_DECLARE_FINAL_OBJECT_INFO(EvolutionarySearchNode, SearchStrategyNode);
403
404 void InitializeWithTuneContext(const TuneContext& ctx) final {
405 CHECK(ctx->num_threads > 0) << "ValueError: `TuneContext.num_threads` must be > 0";
406 CHECK(ctx->space_generator.defined())
407 << "ValueError: `TuneContext.space_generator` must be defined";
408 CHECK(ctx->space_generator.value()->postprocs.defined())
409 << "ValueError: `TuneContext.space_generator.postprocs` must be defined";
410 CHECK(ctx->space_generator.value()->mutator_probs.defined())
411 << "ValueError: `TuneContext.space_generator.mutator_probs` must be defined";
412 this->ctx_ = ctx.get();
413 this->postprocs_ = ctx->space_generator.value()->postprocs.value();
414 this->mutator_probs_ = ctx->space_generator.value()->mutator_probs.value();
415 this->rand_state_ = ForkSeed(&ctx->rand_state);
416 this->state_.reset();
417 }
418
419 void PreTuning(int max_trials, int num_trials_per_iter, const Array<Schedule>& design_spaces,
420 const Optional<Database>& database, const Optional<CostModel>& cost_model) final {
421 ICHECK(!design_spaces.empty());
422 CHECK(this->ctx_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?";
423 CHECK(database.defined())
424 << "ValueError: Database is not supplied in PreTuning. Evolutionary"
425 "search algorithm requires a database to be present, so that it "
426 "could sample from previously-explored population. If you do not "
427 "intent to store data on disk, please use `tvm.meta_schedule.database.MemoryDatabase`";
428 CHECK(cost_model.defined())
429 << "ValueError: CostModel is not supplied in PreTuning. Evolutionary search "
430 "algorithm expects a cost model to filter out potentially less efficient kernels. If "
431 "you do not expect a cost model to help, please use "
432 "`tvm.meta_schedule.cost_model.RandomModel`";
433 CHECK(this->state_ == nullptr)
434 << "ValueError: `PreTuning` is already invoked without corresponding `PostTuning`.";
435 this->state_ = std::make_unique<State>(this, max_trials, num_trials_per_iter, design_spaces,
436 database.value(), cost_model.value());
437 }
438
439 void PostTuning() final {
440 CHECK(this->state_ != nullptr) << "ValueError: `PostTuning` is invoked without corresponding "
441 "`PreTuning`, or `PostTuning` is already invoked.";
442 this->state_.reset();
443 }
444
445 Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final {
446 ICHECK(this->state_ != nullptr);
447 return this->state_->GenerateMeasureCandidates();
448 }
449
450 void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
451 const Array<RunnerResult>& results) final {
452 ICHECK(this->state_ != nullptr);
453 this->state_->NotifyRunnerResults(measure_candidates, results);
454 }
455
456 SearchStrategy Clone() const final {
457 ObjectPtr<EvolutionarySearchNode> n = make_object<EvolutionarySearchNode>();
458 n->population_size = this->population_size;
459 n->num_empty_iters_before_early_stop = this->num_empty_iters_before_early_stop;
460 n->init_measured_ratio = this->init_measured_ratio;
461 n->init_min_unmeasured = this->init_min_unmeasured;
462 n->max_fail_count = this->max_fail_count;
463 n->genetic_num_iters = this->genetic_num_iters;
464 n->genetic_mutate_prob = this->genetic_mutate_prob;
465 n->genetic_max_fail_count = this->genetic_max_fail_count;
466 n->eps_greedy = this->eps_greedy;
467 n->ctx_ = this->ctx_;
468 n->rand_state_ = this->rand_state_;
469 n->state_ = nullptr; // cleared the state
470 return SearchStrategy(n);
471 }
472};
473
474std::vector<Schedule> EvolutionarySearchNode::State::PickBestFromDatabase(int num) {
475 auto _ = Profiler::TimedScope("EvoSearch/PickBestFromDatabase");
476 std::vector<tir::Trace> measured_traces;
477 measured_traces.reserve(num);
478 Array<TuningRecord> top_records = this->database_->GetTopK(this->token_, num);
479 for (TuningRecord record : top_records) {
480 measured_traces.push_back(record->trace);
481 }
482 int actual_num = measured_traces.size();
483 ThreadedTraceApply pp(self->postprocs_);
484 std::vector<Schedule> results(actual_num, Schedule{nullptr});
485 auto f_proc_measured = [this, &measured_traces, &results, &pp](int thread_id,
486 int trace_id) -> void {
487 PerThreadData& data = this->per_thread_data_.at(thread_id);
488 TRandState* rand_state = &data.rand_state;
489 const IRModule& mod = data.mod;
490 tir::Trace trace = measured_traces.at(trace_id);
491 Schedule& result = results.at(trace_id);
492 ICHECK(!result.defined());
493 if (Optional<Schedule> sch = pp.Apply(mod, trace, rand_state)) {
494 result = sch.value();
495 } else {
496 LOG(FATAL) << "ValueError: Cannot postprocess the trace:\n" << trace;
497 throw;
498 }
499 };
500 support::parallel_for_dynamic(0, actual_num, self->ctx_->num_threads, f_proc_measured);
501 return results;
502}
503
504std::vector<Schedule> EvolutionarySearchNode::State::SampleInitPopulation(int num) {
505 auto _ = Profiler::TimedScope("EvoSearch/SampleInitPopulation");
506 ThreadedTraceApply pp(self->postprocs_);
507 std::vector<Schedule> out_schs;
508 int fail_count = 0;
509 while (static_cast<int>(out_schs.size()) < self->init_min_unmeasured &&
510 fail_count < self->max_fail_count) {
511 std::vector<Schedule> results(num, Schedule{nullptr});
512 auto f_proc_unmeasured = [this, &results, &pp](int thread_id, int trace_id) -> void {
513 PerThreadData& data = this->per_thread_data_.at(thread_id);
514 TRandState* rand_state = &data.rand_state;
515 const IRModule& mod = data.mod;
516 Schedule& result = results.at(trace_id);
517 ICHECK(!result.defined());
518 int design_space_index = tir::SampleInt(rand_state, 0, design_spaces.size());
519 tir::Trace trace(design_spaces[design_space_index]->insts, {});
520 if (Optional<Schedule> sch = pp.Apply(mod, trace, rand_state)) {
521 result = sch.value();
522 }
523 };
524 support::parallel_for_dynamic(0, num, self->ctx_->num_threads, f_proc_unmeasured);
525 bool found_new = false;
526 for (int i = 0; i < num; i++) {
527 if (results[i].defined()) {
528 found_new = true;
529 out_schs.push_back(results[i]);
530 }
531 }
532 fail_count += !found_new;
533 TVM_PY_LOG(INFO, self->ctx_->logger) << "Sample-Init-Population summary:\n"
534 << pp.SummarizeFailures();
535 }
536 return out_schs;
537}
538
539std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel(
540 std::vector<Schedule> population, int num) {
541 IRModuleSet exists(database_->GetModuleEquality());
542 {
543 auto _ = Profiler::TimedScope("EvoSearch/Evolve/Misc/CopyMeasuredWorkloads");
544 ICHECK_GT(num, 0);
545 // The heap to record best schedule, we do not consider schedules that are already measured
546 exists = this->measured_workloads_;
547 }
548 SizedHeap heap(num);
549 for (int iter = 0;; ++iter) {
550 // Predict normalized score with the cost model,
551 std::vector<double> scores =
552 PredictNormalizedScore(population, GetRef<TuneContext>(self->ctx_), this->cost_model_);
553
554 {
555 auto _ = Profiler::TimedScope("EvoSearch/Evolve/Misc");
556 ICHECK_EQ(scores.size(), population.size());
557 for (int i = 0, n = population.size(); i < n; ++i) {
558 Schedule sch = population.at(i);
559 IRModule mod = sch->mod();
560 size_t shash = ModuleHash(mod);
561 double score = scores.at(i);
562 if (!exists.Has(mod, shash)) {
563 exists.Add(mod, shash);
564 heap.Push(sch, score);
565 }
566 }
567 // Discontinue once it reaches end of search
568 if (iter == self->genetic_num_iters) {
569 break;
570 }
571 // Set threaded samplers, with probability from predicated normalized throughput
572 for (PerThreadData& data : this->per_thread_data_) {
573 data.Set(scores, self->genetic_mutate_prob, self->mutator_probs_);
574 }
575 }
576 {
577 auto _ = Profiler::TimedScope("EvoSearch/Evolve/Mutation");
578 ThreadedTraceApply pp(self->postprocs_);
579 ConcurrentBitmask cbmask(self->population_size);
580 std::vector<Schedule> next_population(self->population_size, Schedule{nullptr});
581 // The worker function
582 auto f_find_candidate = [&cbmask, &population, &next_population, &pp, this](int thread_id,
583 int trace_id) {
584 // Prepare samplers
585 PerThreadData& data = this->per_thread_data_.at(thread_id);
586 TRandState* rand_state = &data.rand_state;
587 const IRModule& mod = data.mod;
588 std::function<int()>& trace_sampler = data.trace_sampler;
589 std::function<Optional<Mutator>()>& mutator_sampler = data.mutator_sampler;
590 Schedule& result = next_population.at(trace_id);
591 int sampled_trace_id = -1;
592 // Loop until success
593 for (int fail_count = 0; fail_count <= self->genetic_max_fail_count; ++fail_count) {
594 sampled_trace_id = trace_sampler();
595 tir::Trace trace = population.at(sampled_trace_id)->trace().value();
596 if (Optional<Mutator> opt_mutator = mutator_sampler()) {
597 // Decision: mutate
598 Mutator mutator = opt_mutator.value();
599 if (Optional<tir::Trace> new_trace = mutator->Apply(trace, rand_state)) {
600 if (Optional<Schedule> sch = pp.Apply(mod, new_trace.value(), rand_state)) {
601 // note that sch's trace is different from new_trace
602 // because it contains post-processing information
603 result = sch.value();
604 break;
605 }
606 }
607 } else if (cbmask.QueryAndMark(sampled_trace_id)) {
608 // Decision: do not mutate
609 break;
610 }
611 }
612 // if retry count exceeds the limit, reuse an old sample
613 if (!result.defined()) {
614 result = population.at(sampled_trace_id);
615 }
616 };
617 support::parallel_for_dynamic(0, self->population_size, self->ctx_->num_threads,
618 f_find_candidate);
619
620 population.swap(next_population);
621 TVM_PY_LOG(INFO, self->ctx_->logger) << "Evolve iter #" << iter << " done. Summary:\n"
622 << pp.SummarizeFailures();
623 }
624 }
625 // Return the best states from the heap, sorting from higher score to lower ones
626 {
627 auto _ = Profiler::TimedScope("EvoSearch/Evolve/Misc");
628 std::sort(heap.heap.begin(), heap.heap.end());
629 std::vector<Schedule> results;
630 results.reserve(num);
631 for (const SizedHeap::Item& item : heap.heap) {
632 results.push_back(item.sch);
633 }
634
635 constexpr int kNumScoresPerLine = 16;
636 std::ostringstream os;
637 int n = heap.heap.size();
638 for (int st = 0; st < n; st += kNumScoresPerLine) {
639 os << std::endl;
640 int ed = std::min(st + kNumScoresPerLine, n);
641 os << "[" << (st + 1) << " : " << ed << "]:\t";
642 for (int i = st; i < ed; ++i) {
643 if (i != st) {
644 os << " ";
645 }
646 os << std::fixed << std::setprecision(4) << heap.heap.at(i).score;
647 }
648 }
649 TVM_PY_LOG(INFO, self->ctx_->logger)
650 << "Scores of the best " << n << " candidates:" << os.str();
651 return results;
652 }
653}
654
655std::vector<Schedule> EvolutionarySearchNode::State::PickWithEpsGreedy(
656 const std::vector<Schedule>& unmeasured, const std::vector<Schedule>& bests, int num) {
657 auto _ = Profiler::TimedScope("EvoSearch/PickWithEpsGreedy");
658 int num_rands = num * self->eps_greedy;
659 int num_bests = num - num_rands;
660 std::vector<int> rands =
661 tir::SampleWithoutReplacement(&self->rand_state_, unmeasured.size(), unmeasured.size());
662 std::vector<Schedule> results;
663 results.reserve(num);
664 IRModuleSet& measured_workloads = this->measured_workloads_;
665 for (int i = 0, i_bests = 0, i_rands = 0; i < num; ++i) {
666 bool has_best = i_bests < static_cast<int>(bests.size());
667 bool has_rand = i_rands < static_cast<int>(rands.size());
668 // Pick a schedule
669 Schedule sch{nullptr};
670 // If needs `bests`, then prefer `bests`
671 if (i < num_bests) {
672 if (has_best) {
673 sch = bests[i_bests++];
674 } else if (has_rand) {
675 sch = unmeasured[rands[i_rands++]];
676 } else {
677 break;
678 }
679 } else {
680 // Else prefer `rands`
681 if (has_rand) {
682 sch = unmeasured[rands[i_rands++]];
683 } else if (has_best) {
684 sch = bests[i_bests++];
685 } else {
686 break;
687 }
688 }
689 IRModule mod = sch->mod();
690 size_t shash = ModuleHash(mod);
691 if (!measured_workloads.Has(mod, shash)) {
692 measured_workloads.Add(mod, shash);
693 results.push_back(sch);
694 }
695 }
696 return results;
697}
698
699Optional<Array<MeasureCandidate>> EvolutionarySearchNode::State::GenerateMeasureCandidates() {
700 if (st >= max_trials) {
701 return NullOpt;
702 }
703 int sample_num = num_trials_per_iter;
704 if (ed > max_trials) {
705 sample_num = max_trials - st;
706 ed = max_trials;
707 }
708 ICHECK_LT(st, ed);
709 int pop = self->population_size;
710 std::vector<Schedule> inits;
711 inits.reserve(pop);
712
713 TVM_PY_LOG(INFO, self->ctx_->logger) << "Generating candidates......";
714 std::vector<Schedule> measured = PickBestFromDatabase(pop * self->init_measured_ratio);
715 TVM_PY_LOG(INFO, self->ctx_->logger)
716 << "Picked top " << measured.size() << " candidate(s) from database";
717 std::vector<Schedule> unmeasured = SampleInitPopulation(pop - measured.size());
718 if (static_cast<int>(unmeasured.size()) < self->init_min_unmeasured) {
719 TVM_PY_LOG(WARNING, self->ctx_->logger)
720 << "Cannot sample enough initial population, evolutionary search failed.";
721 return NullOpt;
722 }
723 TVM_PY_LOG(INFO, self->ctx_->logger) << "Sampled " << unmeasured.size() << " candidate(s)";
724 inits.insert(inits.end(), measured.begin(), measured.end());
725 inits.insert(inits.end(), unmeasured.begin(), unmeasured.end());
726 std::vector<Schedule> bests = EvolveWithCostModel(inits, sample_num);
727 TVM_PY_LOG(INFO, self->ctx_->logger)
728 << "Got " << bests.size() << " candidate(s) with evolutionary search";
729 std::vector<Schedule> picks = PickWithEpsGreedy(unmeasured, bests, sample_num);
730 TVM_PY_LOG(INFO, self->ctx_->logger)
731 << "Sending " << picks.size() << " candidates(s) for measurement";
732 if (picks.empty()) {
733 ++this->num_empty_iters;
734 if (this->num_empty_iters >= self->num_empty_iters_before_early_stop) {
735 return NullOpt;
736 }
737 }
738 return AssembleCandidates(picks);
739}
740
741void EvolutionarySearchNode::State::NotifyRunnerResults(
742 const Array<MeasureCandidate>& measure_candidates, const Array<RunnerResult>& results) {
743 st += results.size();
744 ed += results.size();
745}
746
747size_t EvolutionarySearchNode::State::ModuleHash(const IRModule& mod) const {
748 return database_->GetModuleEquality().Hash(mod);
749}
750
751SearchStrategy SearchStrategy::EvolutionarySearch(int population_size, //
752 double init_measured_ratio, //
753 int init_min_unmeasured, //
754 int max_fail_count, //
755 int genetic_num_iters, //
756 double genetic_mutate_prob, //
757 int genetic_max_fail_count, //
758 double eps_greedy) {
759 TVM_META_SCHEDULE_CHECK_PROB_RANGE(init_measured_ratio, "Initial measured ratio");
760 TVM_META_SCHEDULE_CHECK_PROB_RANGE(genetic_mutate_prob, "Mutation probability");
761 TVM_META_SCHEDULE_CHECK_PROB_RANGE(eps_greedy, "Greedy pick probability");
762 ObjectPtr<EvolutionarySearchNode> n = make_object<EvolutionarySearchNode>();
763 n->population_size = population_size;
764 n->num_empty_iters_before_early_stop = 5;
765 n->init_measured_ratio = init_measured_ratio;
766 n->init_min_unmeasured = init_min_unmeasured;
767 n->max_fail_count = max_fail_count;
768 n->genetic_num_iters = genetic_num_iters;
769 n->genetic_max_fail_count = genetic_max_fail_count;
770 n->genetic_mutate_prob = genetic_mutate_prob;
771 n->eps_greedy = eps_greedy;
772 return SearchStrategy(n);
773}
774
/*!
 * \brief Managed reference to EvolutionarySearchNode.
 * \note It is the type of the `self` argument of the helper functions
 *  `EvolutionarySearchSampleInitPopulation` and `EvolutionarySearchEvolveWithCostModel`.
 */
class EvolutionarySearch : public SearchStrategy {
 public:
  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(EvolutionarySearch, SearchStrategy,
                                                    EvolutionarySearchNode);
};
780
781Array<Schedule> EvolutionarySearchSampleInitPopulation(EvolutionarySearch self, int num) {
782 std::vector<Schedule> results = self->state_->SampleInitPopulation(num);
783 return Array<Schedule>(results.begin(), results.end());
784}
785
786Array<Schedule> EvolutionarySearchEvolveWithCostModel(EvolutionarySearch self,
787 Array<Schedule> population, int num) {
788 Array<Schedule> result;
789 std::vector<Schedule> population_vec =
790 std::vector<Schedule>(population.begin(), population.end());
791 std::vector<Schedule> schs = self->state_->EvolveWithCostModel(population_vec, num);
792 for (Schedule sch : schs) {
793 IRModule mod = sch->mod();
794 size_t shash = self->state_->ModuleHash(mod);
795 if (!self->state_->measured_workloads_.Has(mod, shash)) {
796 self->state_->measured_workloads_.Add(mod, shash);
797 result.push_back(sch);
798 }
799 }
800 return result;
801}
802
// Register the node type, and register the factory function and the two helper functions
// as global functions under the names given below.
TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch")
    .set_body_typed(SearchStrategy::EvolutionarySearch);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation")
    .set_body_typed(EvolutionarySearchSampleInitPopulation);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearchEvolveWithCostModel")
    .set_body_typed(EvolutionarySearchEvolveWithCostModel);
810
811} // namespace meta_schedule
812} // namespace tvm
813