1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #include "../module_equality.h" |
21 | #include "../utils.h" |
22 | |
/*!
 * \brief Check that a probability-like value lies in the closed interval [0, 1].
 * \param p The expression to check. It is evaluated more than once, so it should be free of
 *          side effects.
 * \param name A human-readable description of `p`, included in the failure message.
 * \note The failure message starts with "ValueError:" and contains `name`, the source text of
 *       `p`, and its value.
 */
#define TVM_META_SCHEDULE_CHECK_PROB_RANGE(p, name)                                           \
  CHECK(0.0 <= (p) && (p) <= 1.0) << "ValueError: " << (name) << " should be within [0, 1], " \
                                  << "but get `" << #p << " = " << (p) << '\'';
26 | |
27 | namespace tvm { |
28 | namespace meta_schedule { |
29 | |
30 | using tir::Schedule; |
31 | |
32 | /**************** Data Structure ****************/ |
33 | |
34 | /*! \brief An auxiliary data structure to help deduplicate IRModules */ |
35 | class IRModuleSet { |
36 | public: |
37 | explicit IRModuleSet(const ModuleEquality& mod_eq) |
38 | : tab_(/*bucket_count*/ 0, ItemHash(), ItemEqual(mod_eq)) {} |
39 | |
40 | /*! \brief Add an IRModule to the set */ |
41 | void Add(const IRModule& mod, size_t shash) { tab_.insert(Item{mod, shash}); } |
42 | /*! \brief Check if the IRModule is in the set */ |
43 | bool Has(const IRModule& mod, size_t shash) const { return tab_.count(Item{mod, shash}); } |
44 | |
45 | private: |
46 | struct Item { |
47 | IRModule mod; |
48 | size_t shash; |
49 | }; |
50 | struct ItemHash { |
51 | size_t operator()(const Item& hash) const { return hash.shash; } |
52 | }; |
53 | struct ItemEqual { |
54 | explicit ItemEqual(const ModuleEquality& mod_eq) : mod_eq_(mod_eq) {} |
55 | ItemEqual& operator=(const ItemEqual& other) { return *this; } |
56 | |
57 | bool operator()(const Item& lhs, const Item& rhs) const { |
58 | return lhs.shash == rhs.shash && mod_eq_.Equal(lhs.mod, rhs.mod); |
59 | } |
60 | |
61 | const ModuleEquality& mod_eq_; |
62 | }; |
63 | |
64 | std::unordered_set<Item, ItemHash, ItemEqual> tab_; |
65 | }; |
66 | |
67 | /*! |
68 | * \brief A heap with a size up-limit. If overflow happens, it evicted the worst items. |
69 | * \note It maintains a min heap in terms of `Item::score`. Therefore, when |
70 | * overflow happens, the element evicted is the one with the min `Item::score`. |
71 | * As time goes, the elements in the heap are going to be larger. |
72 | */ |
73 | class SizedHeap { |
74 | public: |
75 | struct Item { |
76 | Schedule sch; |
77 | double score; |
78 | bool operator<(const Item& other) const { return score > other.score; } |
79 | }; |
80 | |
81 | /*! |
82 | * \brief Constructor |
83 | * \param size_limit The up-limit of the heap size |
84 | */ |
85 | explicit SizedHeap(int size_limit) : size_limit(size_limit) { heap.reserve(size_limit); } |
86 | |
87 | /*! |
88 | * \brief Push the specific item to the heap if its key did not appears in the heap |
89 | * \param item The item to be pushed |
90 | */ |
91 | void Push(Schedule sch, double score) { |
92 | int size = heap.size(); |
93 | if (size < size_limit) { |
94 | // Heap is not full, just push |
95 | heap.emplace_back(Item{sch, score}); |
96 | std::push_heap(heap.begin(), heap.end()); |
97 | } else if (score > heap.front().score) { |
98 | // if the item is better than the worst one in the heap, we can safely kick it out |
99 | std::pop_heap(heap.begin(), heap.end()); |
100 | heap.back() = {sch, score}; |
101 | std::push_heap(heap.begin(), heap.end()); |
102 | } |
103 | // Otherwise, the item is worse than any other element in the heap |
104 | } |
105 | |
106 | /*! \brief Up-limit of the heap size */ |
107 | int size_limit; |
108 | /*! \brief The heap, the worse the topper */ |
109 | std::vector<Item> heap; |
110 | }; |
111 | |
112 | struct PerThreadData { |
113 | IRModule mod{nullptr}; |
114 | TRandState rand_state{-1}; |
115 | std::function<int32_t()> trace_sampler = nullptr; |
116 | std::function<Optional<Mutator>()> mutator_sampler = nullptr; |
117 | |
118 | /*! |
119 | * \brief Set the value for the trace and mutator samplers per thread. |
120 | * \param scores The predicted score for the given samples. |
121 | * \param genetic_mutate_prob The probability of mutation. |
122 | * \param mutator_probs The probability of each mutator as a dict. |
123 | */ |
124 | void Set(const std::vector<double>& scores, double genetic_mutate_prob, |
125 | const Map<Mutator, FloatImm>& mutator_probs) { |
126 | trace_sampler = tir::MakeMultinomialSampler(&rand_state, scores); |
127 | mutator_sampler = MakeMutatorSampler(genetic_mutate_prob, mutator_probs, &rand_state); |
128 | } |
129 | |
130 | private: |
131 | /*! |
132 | * \brief Create a sampler function that picks mutators according to the mass function |
133 | * \param rand_state The random state for sampling |
134 | * \return The sampler created |
135 | */ |
136 | static std::function<Optional<Mutator>()> MakeMutatorSampler( |
137 | double genetic_mutate_prob, // |
138 | const Map<Mutator, FloatImm>& mutator_probs, // |
139 | TRandState* rand_state) { |
140 | std::vector<Optional<Mutator>> mutators; |
141 | std::vector<double> masses; |
142 | mutators.push_back(NullOpt); |
143 | masses.push_back(1.0 - genetic_mutate_prob); |
144 | double total_mass_mutator = 0.0; |
145 | if (genetic_mutate_prob > 0) { |
146 | for (const auto& kv : mutator_probs) { |
147 | Mutator mutator = kv.first; |
148 | double mass = kv.second->value; |
149 | total_mass_mutator += mass; |
150 | mutators.push_back(mutator); |
151 | masses.push_back(mass * genetic_mutate_prob); |
152 | } |
153 | } |
154 | // Normalize the sum to 1.0 |
155 | if (total_mass_mutator == 0.0) { |
156 | masses[0] = 1.0; |
157 | for (int i = 1, n = masses.size(); i < n; ++i) { |
158 | masses[i] = 0.0; |
159 | } |
160 | } else if (total_mass_mutator != 1.0) { |
161 | for (int i = 1, n = masses.size(); i < n; ++i) { |
162 | masses[i] /= total_mass_mutator; |
163 | } |
164 | } |
165 | return [idx_sampler = tir::MakeMultinomialSampler(rand_state, masses), |
166 | mutators = std::move(mutators)]() -> Optional<Mutator> { |
167 | int i = idx_sampler(); |
168 | return mutators[i]; |
169 | }; |
170 | } |
171 | }; |
172 | |
/*!
 * \brief A fixed-size set of flags supporting concurrent test-and-set.
 * \note Flags are packed into 64-bit words. Each word has its own mutex, so operations on
 *       different words do not contend.
 */
struct ConcurrentBitmask {
  /*! \brief The number of flags packed into one word. */
  static constexpr const int kBitWidth = 64;
  /*! \brief The number of words, i.e. ceil(n / kBitWidth) for the constructor argument `n`. */
  int size;
  /*! \brief The packed flags, one 64-bit word per `kBitWidth` slots. */
  std::vector<uint64_t> bitmask;
  /*! \brief The mutexes, one per word of `bitmask`. */
  std::vector<std::mutex> mutexes;

  /*!
   * \brief Constructor
   * \param n The total slots managed by the concurrent bitmask.
   */
  explicit ConcurrentBitmask(int n)
      : size((n + kBitWidth - 1) / kBitWidth), bitmask(size, 0), mutexes(size) {}

  /*!
   * \brief Check slot `x` and mark it as used, under the lock of its word.
   * \param x The slot index, valid in `[0, size * kBitWidth)`.
   * \return `true` if the slot was unused before this call (it is now marked);
   *         `false` if it had already been marked.
   * \throws std::out_of_range if `x` is negative or maps past the last word.
   */
  bool QueryAndMark(int x) {
    constexpr uint64_t one = 1;
    // Unsigned arithmetic: a negative `x` becomes a huge word index that `at()` rejects, and the
    // shift amount is always within [0, kBitWidth), so no undefined behavior is possible.
    const uint64_t ux = static_cast<uint64_t>(x);
    const size_t word = static_cast<size_t>(ux / kBitWidth);
    const uint64_t bit = one << (ux % kBitWidth);
    std::unique_lock<std::mutex> lock(mutexes.at(word));
    uint64_t& flags = bitmask.at(word);
    if (flags & bit) {
      return false;
    }
    flags |= bit;
    return true;
  }
};
205 | |
206 | /**************** Util Functions ****************/ |
207 | |
208 | /*! |
209 | * \brief Assemble measure candidates from the given candidate traces. |
210 | * \param traces The picked candidate traces. |
211 | * \return The assembled measure candidates. |
212 | */ |
213 | Array<MeasureCandidate> AssembleCandidates(const std::vector<Schedule>& picks) { |
214 | Array<MeasureCandidate> measure_inputs; |
215 | measure_inputs.reserve(picks.size()); |
216 | for (const Schedule& sch : picks) { |
217 | measure_inputs.push_back( |
218 | MeasureCandidate(sch, ArgInfo::FromEntryFunc(sch->mod(), /*remove_preproc=*/true))); |
219 | } |
220 | return measure_inputs; |
221 | } |
222 | |
223 | /*! |
224 | * \brief Predict the normalized score of each candidate. |
225 | * \param candidates The candidates for prediction |
226 | * \param task The search task |
227 | * \param space The search space |
228 | * \return The normalized score in the prediction |
229 | */ |
230 | std::vector<double> PredictNormalizedScore(const std::vector<Schedule>& candidates, |
231 | const TuneContext& context, |
232 | const CostModel& cost_model) { |
233 | auto _ = Profiler::TimedScope("EvoSearch/Evolve/PredictNormalizedScore" ); |
234 | ICHECK(!candidates.empty()) << "Candidates given for score prediction can not be empty list!" ; |
235 | std::vector<double> scores = cost_model->Predict(context, AssembleCandidates(candidates)); |
236 | for (double& score : scores) { |
237 | score = std::max(0.0, score); |
238 | } |
239 | return scores; |
240 | } |
241 | |
242 | /**************** Evolutionary Search ****************/ |
243 | |
244 | /*!\brief A search strategy that generates measure candidates using evolutionary search. */ |
245 | class EvolutionarySearchNode : public SearchStrategyNode { |
246 | public: |
247 | /*! \brief The state of the search strategy. */ |
248 | struct State { |
249 | /*! \brief The search strategy itself */ |
250 | EvolutionarySearchNode* self; |
251 | /*! \brief The number of total trials. */ |
252 | int max_trials; |
253 | /*! \brief The number of trials per iteration. */ |
254 | int num_trials_per_iter; |
255 | /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ |
256 | int st; |
257 | /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ |
258 | int ed; |
259 | /*! \brief The counter of returning empty results. */ |
260 | int num_empty_iters; |
261 | /*! \brief The design spaces. Decisions are not used so traces only. */ |
262 | Array<tir::Trace> design_spaces; |
263 | /*! \brief Pre thread data including module to be tuned and random state. */ |
264 | std::vector<PerThreadData> per_thread_data_; |
265 | /*! |
266 | * \brief The workloads that are already measured. |
267 | * TODO(junrushao1994): add records from the database to avoid re-measuring. |
268 | * */ |
269 | IRModuleSet measured_workloads_; |
270 | /*! \brief A Database for selecting useful candidates. */ |
271 | Database database_{nullptr}; |
272 | /*! \brief A cost model helping to explore the search space */ |
273 | CostModel cost_model_{nullptr}; |
274 | /*! \brief The token registered for the given workload in database. */ |
275 | Workload token_{nullptr}; |
276 | |
277 | explicit State(EvolutionarySearchNode* self, int max_trials, int num_trials_per_iter, |
278 | Array<Schedule> design_space_schedules, Database database, CostModel cost_model) |
279 | : self(self), |
280 | max_trials(max_trials), |
281 | num_trials_per_iter(num_trials_per_iter), |
282 | st(0), |
283 | ed(num_trials_per_iter), |
284 | num_empty_iters(0), |
285 | measured_workloads_(database->GetModuleEquality()) { |
286 | design_spaces.reserve(design_spaces.size()); |
287 | for (const Schedule& space : design_space_schedules) { |
288 | design_spaces.push_back(space->trace().value()->Simplified(true)); |
289 | } |
290 | const TuneContextNode* ctx = self->ctx_; |
291 | IRModule mod = ctx->mod.value(); |
292 | this->per_thread_data_.resize(ctx->num_threads); |
293 | for (PerThreadData& data : this->per_thread_data_) { |
294 | data.mod = DeepCopyIRModule(mod); |
295 | data.rand_state = ForkSeed(&self->rand_state_); |
296 | } |
297 | this->database_ = database; |
298 | this->cost_model_ = cost_model; |
299 | this->token_ = database->CommitWorkload(mod); |
300 | } |
301 | |
302 | /*! |
303 | * \brief Pick up best candidates from database. |
304 | * \param num The number of traces to produce. |
305 | * \return The picked best candidates. |
306 | */ |
307 | inline std::vector<Schedule> PickBestFromDatabase(int num); |
308 | /*! |
309 | * \brief Sample the initial population from previous measured results and randomly generated |
310 | * traces via trace replaying. |
311 | * \param num The number of traces to produce. |
312 | * \return The initial population of traces sampled. |
313 | */ |
314 | inline std::vector<Schedule> SampleInitPopulation(int num); |
315 | /*! |
316 | * \brief Evolve the initial population using mutators and samplers. |
317 | * \param population The initial population of traces sampled. |
318 | * \param num The number of traces to produce. |
319 | * \return The evolved traces from initial population. |
320 | */ |
321 | inline std::vector<Schedule> EvolveWithCostModel(std::vector<Schedule> population, int num); |
322 | /*! |
323 | * \brief Pick final candidates from the given initial population and bests of evolved ones. |
324 | * \param inits The initial population of traces sampled. |
325 | * \param bests The best candidates predicted from evolved traces. |
326 | * \param num The number of traces to produce. |
327 | * \return The final picked candidates with a ratio of both. |
328 | */ |
329 | inline std::vector<Schedule> PickWithEpsGreedy(const std::vector<Schedule>& inits, |
330 | const std::vector<Schedule>& bests, int num); |
331 | /*! \brief An interface method to be called by it's counterpart in EvolutionarySearchNode */ |
332 | inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates(); |
333 | /*! \brief An interface method to be called by it's counterpart in EvolutionarySearchNode */ |
334 | inline void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates, |
335 | const Array<RunnerResult>& results); |
336 | /*! |
337 | * \brief Compute the hash for the given module. |
338 | * \param mod The input TIR module. |
339 | * \return The calculated hash. |
340 | */ |
341 | inline size_t ModuleHash(const IRModule& mod) const; |
342 | }; |
343 | |
344 | /*! \brief The tuning context of the evolutionary search strategy. */ |
345 | const TuneContextNode* ctx_{nullptr}; |
346 | /*! \brief The postprocessors */ |
347 | Array<Postproc> postprocs_; |
348 | /*! \brief The mutators and their probability. */ |
349 | Map<Mutator, FloatImm> mutator_probs_; |
350 | /*! \brief The random state. To be initialized with TuneContext. */ |
351 | TRandState rand_state_; |
352 | /*! \brief The state of the search strategy. */ |
353 | std::unique_ptr<State> state_ = nullptr; |
354 | |
355 | /*** Configuration: global ***/ |
356 | /*! \brief The population size in the evolutionary search. */ |
357 | int population_size; |
358 | /*! |
359 | * \brief The maximum number of iterations before early stopping to confirm the search space is |
360 | * exhausted |
361 | */ |
362 | int num_empty_iters_before_early_stop; |
363 | /*** Configuration: the initial population ***/ |
364 | /*! \brief The ratio of measured states used in the initial population */ |
365 | double init_measured_ratio; |
366 | /*! \brief The minimal size of unmeasured population in the initial sampling.*/ |
367 | int init_min_unmeasured; |
368 | /*! \brief The maximum number of failure during initial sampling. */ |
369 | int max_fail_count; |
370 | /*** Configuration: evolution ***/ |
371 | /*! \brief The number of iterations performed by generic algorithm. */ |
372 | int genetic_num_iters; |
373 | /*! \brief The probability to perform mutation */ |
374 | double genetic_mutate_prob; |
375 | /*! \brief The maximum number to try evolving the given trace. */ |
376 | int genetic_max_fail_count; |
377 | /*** Configuration: pick states for measurement ***/ |
378 | /*! \brief The ratio of measurements to use randomly sampled states. */ |
379 | double eps_greedy; |
380 | |
381 | void VisitAttrs(tvm::AttrVisitor* v) { |
382 | // `context_` is not visited |
383 | // `rand_state_` is not visited |
384 | // `state_` is not visited |
385 | |
386 | /*** Configuration: global ***/ |
387 | v->Visit("population_size" , &population_size); |
388 | v->Visit("num_empty_iters_before_early_stop" , &num_empty_iters_before_early_stop); |
389 | /*** Configuration: the initial population ***/ |
390 | v->Visit("init_measured_ratio" , &init_measured_ratio); |
391 | v->Visit("init_min_unmeasured" , &init_min_unmeasured); |
392 | v->Visit("max_fail_count" , &max_fail_count); |
393 | /*** Configuration: evolution ***/ |
394 | v->Visit("genetic_num_iters" , &genetic_num_iters); |
395 | v->Visit("genetic_mutate_prob" , &genetic_mutate_prob); |
396 | v->Visit("genetic_max_fail_count" , &genetic_max_fail_count); |
397 | /*** Configuration: pick states for measurement ***/ |
398 | v->Visit("eps_greedy" , &eps_greedy); |
399 | } |
400 | |
401 | static constexpr const char* _type_key = "meta_schedule.EvolutionarySearch" ; |
402 | TVM_DECLARE_FINAL_OBJECT_INFO(EvolutionarySearchNode, SearchStrategyNode); |
403 | |
404 | void InitializeWithTuneContext(const TuneContext& ctx) final { |
405 | CHECK(ctx->num_threads > 0) << "ValueError: `TuneContext.num_threads` must be > 0" ; |
406 | CHECK(ctx->space_generator.defined()) |
407 | << "ValueError: `TuneContext.space_generator` must be defined" ; |
408 | CHECK(ctx->space_generator.value()->postprocs.defined()) |
409 | << "ValueError: `TuneContext.space_generator.postprocs` must be defined" ; |
410 | CHECK(ctx->space_generator.value()->mutator_probs.defined()) |
411 | << "ValueError: `TuneContext.space_generator.mutator_probs` must be defined" ; |
412 | this->ctx_ = ctx.get(); |
413 | this->postprocs_ = ctx->space_generator.value()->postprocs.value(); |
414 | this->mutator_probs_ = ctx->space_generator.value()->mutator_probs.value(); |
415 | this->rand_state_ = ForkSeed(&ctx->rand_state); |
416 | this->state_.reset(); |
417 | } |
418 | |
419 | void PreTuning(int max_trials, int num_trials_per_iter, const Array<Schedule>& design_spaces, |
420 | const Optional<Database>& database, const Optional<CostModel>& cost_model) final { |
421 | ICHECK(!design_spaces.empty()); |
422 | CHECK(this->ctx_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?" ; |
423 | CHECK(database.defined()) |
424 | << "ValueError: Database is not supplied in PreTuning. Evolutionary" |
425 | "search algorithm requires a database to be present, so that it " |
426 | "could sample from previously-explored population. If you do not " |
427 | "intent to store data on disk, please use `tvm.meta_schedule.database.MemoryDatabase`" ; |
428 | CHECK(cost_model.defined()) |
429 | << "ValueError: CostModel is not supplied in PreTuning. Evolutionary search " |
430 | "algorithm expects a cost model to filter out potentially less efficient kernels. If " |
431 | "you do not expect a cost model to help, please use " |
432 | "`tvm.meta_schedule.cost_model.RandomModel`" ; |
433 | CHECK(this->state_ == nullptr) |
434 | << "ValueError: `PreTuning` is already invoked without corresponding `PostTuning`." ; |
435 | this->state_ = std::make_unique<State>(this, max_trials, num_trials_per_iter, design_spaces, |
436 | database.value(), cost_model.value()); |
437 | } |
438 | |
439 | void PostTuning() final { |
440 | CHECK(this->state_ != nullptr) << "ValueError: `PostTuning` is invoked without corresponding " |
441 | "`PreTuning`, or `PostTuning` is already invoked." ; |
442 | this->state_.reset(); |
443 | } |
444 | |
445 | Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final { |
446 | ICHECK(this->state_ != nullptr); |
447 | return this->state_->GenerateMeasureCandidates(); |
448 | } |
449 | |
450 | void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates, |
451 | const Array<RunnerResult>& results) final { |
452 | ICHECK(this->state_ != nullptr); |
453 | this->state_->NotifyRunnerResults(measure_candidates, results); |
454 | } |
455 | |
456 | SearchStrategy Clone() const final { |
457 | ObjectPtr<EvolutionarySearchNode> n = make_object<EvolutionarySearchNode>(); |
458 | n->population_size = this->population_size; |
459 | n->num_empty_iters_before_early_stop = this->num_empty_iters_before_early_stop; |
460 | n->init_measured_ratio = this->init_measured_ratio; |
461 | n->init_min_unmeasured = this->init_min_unmeasured; |
462 | n->max_fail_count = this->max_fail_count; |
463 | n->genetic_num_iters = this->genetic_num_iters; |
464 | n->genetic_mutate_prob = this->genetic_mutate_prob; |
465 | n->genetic_max_fail_count = this->genetic_max_fail_count; |
466 | n->eps_greedy = this->eps_greedy; |
467 | n->ctx_ = this->ctx_; |
468 | n->rand_state_ = this->rand_state_; |
469 | n->state_ = nullptr; // cleared the state |
470 | return SearchStrategy(n); |
471 | } |
472 | }; |
473 | |
474 | std::vector<Schedule> EvolutionarySearchNode::State::PickBestFromDatabase(int num) { |
475 | auto _ = Profiler::TimedScope("EvoSearch/PickBestFromDatabase" ); |
476 | std::vector<tir::Trace> measured_traces; |
477 | measured_traces.reserve(num); |
478 | Array<TuningRecord> top_records = this->database_->GetTopK(this->token_, num); |
479 | for (TuningRecord record : top_records) { |
480 | measured_traces.push_back(record->trace); |
481 | } |
482 | int actual_num = measured_traces.size(); |
483 | ThreadedTraceApply pp(self->postprocs_); |
484 | std::vector<Schedule> results(actual_num, Schedule{nullptr}); |
485 | auto f_proc_measured = [this, &measured_traces, &results, &pp](int thread_id, |
486 | int trace_id) -> void { |
487 | PerThreadData& data = this->per_thread_data_.at(thread_id); |
488 | TRandState* rand_state = &data.rand_state; |
489 | const IRModule& mod = data.mod; |
490 | tir::Trace trace = measured_traces.at(trace_id); |
491 | Schedule& result = results.at(trace_id); |
492 | ICHECK(!result.defined()); |
493 | if (Optional<Schedule> sch = pp.Apply(mod, trace, rand_state)) { |
494 | result = sch.value(); |
495 | } else { |
496 | LOG(FATAL) << "ValueError: Cannot postprocess the trace:\n" << trace; |
497 | throw; |
498 | } |
499 | }; |
500 | support::parallel_for_dynamic(0, actual_num, self->ctx_->num_threads, f_proc_measured); |
501 | return results; |
502 | } |
503 | |
504 | std::vector<Schedule> EvolutionarySearchNode::State::SampleInitPopulation(int num) { |
505 | auto _ = Profiler::TimedScope("EvoSearch/SampleInitPopulation" ); |
506 | ThreadedTraceApply pp(self->postprocs_); |
507 | std::vector<Schedule> out_schs; |
508 | int fail_count = 0; |
509 | while (static_cast<int>(out_schs.size()) < self->init_min_unmeasured && |
510 | fail_count < self->max_fail_count) { |
511 | std::vector<Schedule> results(num, Schedule{nullptr}); |
512 | auto f_proc_unmeasured = [this, &results, &pp](int thread_id, int trace_id) -> void { |
513 | PerThreadData& data = this->per_thread_data_.at(thread_id); |
514 | TRandState* rand_state = &data.rand_state; |
515 | const IRModule& mod = data.mod; |
516 | Schedule& result = results.at(trace_id); |
517 | ICHECK(!result.defined()); |
518 | int design_space_index = tir::SampleInt(rand_state, 0, design_spaces.size()); |
519 | tir::Trace trace(design_spaces[design_space_index]->insts, {}); |
520 | if (Optional<Schedule> sch = pp.Apply(mod, trace, rand_state)) { |
521 | result = sch.value(); |
522 | } |
523 | }; |
524 | support::parallel_for_dynamic(0, num, self->ctx_->num_threads, f_proc_unmeasured); |
525 | bool found_new = false; |
526 | for (int i = 0; i < num; i++) { |
527 | if (results[i].defined()) { |
528 | found_new = true; |
529 | out_schs.push_back(results[i]); |
530 | } |
531 | } |
532 | fail_count += !found_new; |
533 | TVM_PY_LOG(INFO, self->ctx_->logger) << "Sample-Init-Population summary:\n" |
534 | << pp.SummarizeFailures(); |
535 | } |
536 | return out_schs; |
537 | } |
538 | |
539 | std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel( |
540 | std::vector<Schedule> population, int num) { |
541 | IRModuleSet exists(database_->GetModuleEquality()); |
542 | { |
543 | auto _ = Profiler::TimedScope("EvoSearch/Evolve/Misc/CopyMeasuredWorkloads" ); |
544 | ICHECK_GT(num, 0); |
545 | // The heap to record best schedule, we do not consider schedules that are already measured |
546 | exists = this->measured_workloads_; |
547 | } |
548 | SizedHeap heap(num); |
549 | for (int iter = 0;; ++iter) { |
550 | // Predict normalized score with the cost model, |
551 | std::vector<double> scores = |
552 | PredictNormalizedScore(population, GetRef<TuneContext>(self->ctx_), this->cost_model_); |
553 | |
554 | { |
555 | auto _ = Profiler::TimedScope("EvoSearch/Evolve/Misc" ); |
556 | ICHECK_EQ(scores.size(), population.size()); |
557 | for (int i = 0, n = population.size(); i < n; ++i) { |
558 | Schedule sch = population.at(i); |
559 | IRModule mod = sch->mod(); |
560 | size_t shash = ModuleHash(mod); |
561 | double score = scores.at(i); |
562 | if (!exists.Has(mod, shash)) { |
563 | exists.Add(mod, shash); |
564 | heap.Push(sch, score); |
565 | } |
566 | } |
567 | // Discontinue once it reaches end of search |
568 | if (iter == self->genetic_num_iters) { |
569 | break; |
570 | } |
571 | // Set threaded samplers, with probability from predicated normalized throughput |
572 | for (PerThreadData& data : this->per_thread_data_) { |
573 | data.Set(scores, self->genetic_mutate_prob, self->mutator_probs_); |
574 | } |
575 | } |
576 | { |
577 | auto _ = Profiler::TimedScope("EvoSearch/Evolve/Mutation" ); |
578 | ThreadedTraceApply pp(self->postprocs_); |
579 | ConcurrentBitmask cbmask(self->population_size); |
580 | std::vector<Schedule> next_population(self->population_size, Schedule{nullptr}); |
581 | // The worker function |
582 | auto f_find_candidate = [&cbmask, &population, &next_population, &pp, this](int thread_id, |
583 | int trace_id) { |
584 | // Prepare samplers |
585 | PerThreadData& data = this->per_thread_data_.at(thread_id); |
586 | TRandState* rand_state = &data.rand_state; |
587 | const IRModule& mod = data.mod; |
588 | std::function<int()>& trace_sampler = data.trace_sampler; |
589 | std::function<Optional<Mutator>()>& mutator_sampler = data.mutator_sampler; |
590 | Schedule& result = next_population.at(trace_id); |
591 | int sampled_trace_id = -1; |
592 | // Loop until success |
593 | for (int fail_count = 0; fail_count <= self->genetic_max_fail_count; ++fail_count) { |
594 | sampled_trace_id = trace_sampler(); |
595 | tir::Trace trace = population.at(sampled_trace_id)->trace().value(); |
596 | if (Optional<Mutator> opt_mutator = mutator_sampler()) { |
597 | // Decision: mutate |
598 | Mutator mutator = opt_mutator.value(); |
599 | if (Optional<tir::Trace> new_trace = mutator->Apply(trace, rand_state)) { |
600 | if (Optional<Schedule> sch = pp.Apply(mod, new_trace.value(), rand_state)) { |
601 | // note that sch's trace is different from new_trace |
602 | // because it contains post-processing information |
603 | result = sch.value(); |
604 | break; |
605 | } |
606 | } |
607 | } else if (cbmask.QueryAndMark(sampled_trace_id)) { |
608 | // Decision: do not mutate |
609 | break; |
610 | } |
611 | } |
612 | // if retry count exceeds the limit, reuse an old sample |
613 | if (!result.defined()) { |
614 | result = population.at(sampled_trace_id); |
615 | } |
616 | }; |
617 | support::parallel_for_dynamic(0, self->population_size, self->ctx_->num_threads, |
618 | f_find_candidate); |
619 | |
620 | population.swap(next_population); |
621 | TVM_PY_LOG(INFO, self->ctx_->logger) << "Evolve iter #" << iter << " done. Summary:\n" |
622 | << pp.SummarizeFailures(); |
623 | } |
624 | } |
625 | // Return the best states from the heap, sorting from higher score to lower ones |
626 | { |
627 | auto _ = Profiler::TimedScope("EvoSearch/Evolve/Misc" ); |
628 | std::sort(heap.heap.begin(), heap.heap.end()); |
629 | std::vector<Schedule> results; |
630 | results.reserve(num); |
631 | for (const SizedHeap::Item& item : heap.heap) { |
632 | results.push_back(item.sch); |
633 | } |
634 | |
635 | constexpr int kNumScoresPerLine = 16; |
636 | std::ostringstream os; |
637 | int n = heap.heap.size(); |
638 | for (int st = 0; st < n; st += kNumScoresPerLine) { |
639 | os << std::endl; |
640 | int ed = std::min(st + kNumScoresPerLine, n); |
641 | os << "[" << (st + 1) << " : " << ed << "]:\t" ; |
642 | for (int i = st; i < ed; ++i) { |
643 | if (i != st) { |
644 | os << " " ; |
645 | } |
646 | os << std::fixed << std::setprecision(4) << heap.heap.at(i).score; |
647 | } |
648 | } |
649 | TVM_PY_LOG(INFO, self->ctx_->logger) |
650 | << "Scores of the best " << n << " candidates:" << os.str(); |
651 | return results; |
652 | } |
653 | } |
654 | |
655 | std::vector<Schedule> EvolutionarySearchNode::State::PickWithEpsGreedy( |
656 | const std::vector<Schedule>& unmeasured, const std::vector<Schedule>& bests, int num) { |
657 | auto _ = Profiler::TimedScope("EvoSearch/PickWithEpsGreedy" ); |
658 | int num_rands = num * self->eps_greedy; |
659 | int num_bests = num - num_rands; |
660 | std::vector<int> rands = |
661 | tir::SampleWithoutReplacement(&self->rand_state_, unmeasured.size(), unmeasured.size()); |
662 | std::vector<Schedule> results; |
663 | results.reserve(num); |
664 | IRModuleSet& measured_workloads = this->measured_workloads_; |
665 | for (int i = 0, i_bests = 0, i_rands = 0; i < num; ++i) { |
666 | bool has_best = i_bests < static_cast<int>(bests.size()); |
667 | bool has_rand = i_rands < static_cast<int>(rands.size()); |
668 | // Pick a schedule |
669 | Schedule sch{nullptr}; |
670 | // If needs `bests`, then prefer `bests` |
671 | if (i < num_bests) { |
672 | if (has_best) { |
673 | sch = bests[i_bests++]; |
674 | } else if (has_rand) { |
675 | sch = unmeasured[rands[i_rands++]]; |
676 | } else { |
677 | break; |
678 | } |
679 | } else { |
680 | // Else prefer `rands` |
681 | if (has_rand) { |
682 | sch = unmeasured[rands[i_rands++]]; |
683 | } else if (has_best) { |
684 | sch = bests[i_bests++]; |
685 | } else { |
686 | break; |
687 | } |
688 | } |
689 | IRModule mod = sch->mod(); |
690 | size_t shash = ModuleHash(mod); |
691 | if (!measured_workloads.Has(mod, shash)) { |
692 | measured_workloads.Add(mod, shash); |
693 | results.push_back(sch); |
694 | } |
695 | } |
696 | return results; |
697 | } |
698 | |
699 | Optional<Array<MeasureCandidate>> EvolutionarySearchNode::State::GenerateMeasureCandidates() { |
700 | if (st >= max_trials) { |
701 | return NullOpt; |
702 | } |
703 | int sample_num = num_trials_per_iter; |
704 | if (ed > max_trials) { |
705 | sample_num = max_trials - st; |
706 | ed = max_trials; |
707 | } |
708 | ICHECK_LT(st, ed); |
709 | int pop = self->population_size; |
710 | std::vector<Schedule> inits; |
711 | inits.reserve(pop); |
712 | |
713 | TVM_PY_LOG(INFO, self->ctx_->logger) << "Generating candidates......" ; |
714 | std::vector<Schedule> measured = PickBestFromDatabase(pop * self->init_measured_ratio); |
715 | TVM_PY_LOG(INFO, self->ctx_->logger) |
716 | << "Picked top " << measured.size() << " candidate(s) from database" ; |
717 | std::vector<Schedule> unmeasured = SampleInitPopulation(pop - measured.size()); |
718 | if (static_cast<int>(unmeasured.size()) < self->init_min_unmeasured) { |
719 | TVM_PY_LOG(WARNING, self->ctx_->logger) |
720 | << "Cannot sample enough initial population, evolutionary search failed." ; |
721 | return NullOpt; |
722 | } |
723 | TVM_PY_LOG(INFO, self->ctx_->logger) << "Sampled " << unmeasured.size() << " candidate(s)" ; |
724 | inits.insert(inits.end(), measured.begin(), measured.end()); |
725 | inits.insert(inits.end(), unmeasured.begin(), unmeasured.end()); |
726 | std::vector<Schedule> bests = EvolveWithCostModel(inits, sample_num); |
727 | TVM_PY_LOG(INFO, self->ctx_->logger) |
728 | << "Got " << bests.size() << " candidate(s) with evolutionary search" ; |
729 | std::vector<Schedule> picks = PickWithEpsGreedy(unmeasured, bests, sample_num); |
730 | TVM_PY_LOG(INFO, self->ctx_->logger) |
731 | << "Sending " << picks.size() << " candidates(s) for measurement" ; |
732 | if (picks.empty()) { |
733 | ++this->num_empty_iters; |
734 | if (this->num_empty_iters >= self->num_empty_iters_before_early_stop) { |
735 | return NullOpt; |
736 | } |
737 | } |
738 | return AssembleCandidates(picks); |
739 | } |
740 | |
741 | void EvolutionarySearchNode::State::NotifyRunnerResults( |
742 | const Array<MeasureCandidate>& measure_candidates, const Array<RunnerResult>& results) { |
743 | st += results.size(); |
744 | ed += results.size(); |
745 | } |
746 | |
747 | size_t EvolutionarySearchNode::State::ModuleHash(const IRModule& mod) const { |
748 | return database_->GetModuleEquality().Hash(mod); |
749 | } |
750 | |
751 | SearchStrategy SearchStrategy::EvolutionarySearch(int population_size, // |
752 | double init_measured_ratio, // |
753 | int init_min_unmeasured, // |
754 | int max_fail_count, // |
755 | int genetic_num_iters, // |
756 | double genetic_mutate_prob, // |
757 | int genetic_max_fail_count, // |
758 | double eps_greedy) { |
759 | TVM_META_SCHEDULE_CHECK_PROB_RANGE(init_measured_ratio, "Initial measured ratio" ); |
760 | TVM_META_SCHEDULE_CHECK_PROB_RANGE(genetic_mutate_prob, "Mutation probability" ); |
761 | TVM_META_SCHEDULE_CHECK_PROB_RANGE(eps_greedy, "Greedy pick probability" ); |
762 | ObjectPtr<EvolutionarySearchNode> n = make_object<EvolutionarySearchNode>(); |
763 | n->population_size = population_size; |
764 | n->num_empty_iters_before_early_stop = 5; |
765 | n->init_measured_ratio = init_measured_ratio; |
766 | n->init_min_unmeasured = init_min_unmeasured; |
767 | n->max_fail_count = max_fail_count; |
768 | n->genetic_num_iters = genetic_num_iters; |
769 | n->genetic_max_fail_count = genetic_max_fail_count; |
770 | n->genetic_mutate_prob = genetic_mutate_prob; |
771 | n->eps_greedy = eps_greedy; |
772 | return SearchStrategy(n); |
773 | } |
774 | |
/*!
 * \brief Managed reference to EvolutionarySearchNode.
 * \note Used by the helper functions below, which reach into the node's `state_`.
 */
class EvolutionarySearch : public SearchStrategy {
 public:
  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(EvolutionarySearch, SearchStrategy,
                                                    EvolutionarySearchNode);
};
780 | |
781 | Array<Schedule> EvolutionarySearchSampleInitPopulation(EvolutionarySearch self, int num) { |
782 | std::vector<Schedule> results = self->state_->SampleInitPopulation(num); |
783 | return Array<Schedule>(results.begin(), results.end()); |
784 | } |
785 | |
786 | Array<Schedule> EvolutionarySearchEvolveWithCostModel(EvolutionarySearch self, |
787 | Array<Schedule> population, int num) { |
788 | Array<Schedule> result; |
789 | std::vector<Schedule> population_vec = |
790 | std::vector<Schedule>(population.begin(), population.end()); |
791 | std::vector<Schedule> schs = self->state_->EvolveWithCostModel(population_vec, num); |
792 | for (Schedule sch : schs) { |
793 | IRModule mod = sch->mod(); |
794 | size_t shash = self->state_->ModuleHash(mod); |
795 | if (!self->state_->measured_workloads_.Has(mod, shash)) { |
796 | self->state_->measured_workloads_.Add(mod, shash); |
797 | result.push_back(sch); |
798 | } |
799 | } |
800 | return result; |
801 | } |
802 | |
// Register the node type, the strategy constructor, and two helpers that run individual
// `State` steps (initial-population sampling and cost-model-guided evolution) on their own.
TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch")
    .set_body_typed(SearchStrategy::EvolutionarySearch);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation")
    .set_body_typed(EvolutionarySearchSampleInitPopulation);
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearchEvolveWithCostModel")
    .set_body_typed(EvolutionarySearchEvolveWithCostModel);
810 | |
811 | } // namespace meta_schedule |
812 | } // namespace tvm |
813 | |