/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file auto_scheduler/search_policy/sketch_policy.cc
 * \brief The search policy that searches in a hierarchical search space defined by sketches.
 * The policy randomly samples programs from the space defined by sketches
 * and uses evolutionary search to fine-tune them.
 */

#include "sketch_policy.h"

#include <tvm/runtime/registry.h>
#include <tvm/support/parallel_for.h>

#include <algorithm>
#include <iomanip>
#include <limits>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "sketch_policy_rules.h"

namespace tvm {
namespace auto_scheduler {

/********** Sketch generation rules **********/
static RuleSkipStage rule_skip_stage;
static RuleAlwaysInline rule_always_inline;
static RuleMultiLevelTiling rule_multi_level_tiling;
static RuleMultiLevelTilingWithFusion rule_multi_level_tiling_with_fusion;
static RuleAddCacheRead rule_add_cache_read_stage;
static RuleAddCacheWrite rule_add_cache_write_stage;
static RuleAddRfactor rule_add_rfactor;
static RuleCrossThreadReduction rule_cross_thread_reduction;
static RuleSimplifyComputeWithConstTensor rule_simplify_compute_with_const_tensor;
static RuleSpecialComputeLocationGPU rule_special_compute_location_gpu;

/********** Init population rules **********/
static InitFillTileSize init_fill_tile_size;
static InitChangeComputeLocation init_change_compute_location;
static InitParallel init_parallel;
static InitUnroll init_unroll;
static InitVectorization init_vectorization;
static InitThreadBind init_thread_bind;

/********** Sketch policy **********/
TVM_REGISTER_NODE_TYPE(SketchPolicyNode);

SketchPolicy::SketchPolicy(SearchTask task, CostModel program_cost_model,
                           Map<String, ObjectRef> params, int seed, int verbose,
                           Optional<Array<SearchCallback>> init_search_callbacks) {
  auto node = make_object<SketchPolicyNode>();
  node->search_task = std::move(task);
  node->program_cost_model = std::move(program_cost_model);
  node->rand_gen = std::mt19937(seed);
  node->params = std::move(params);
  node->verbose = verbose;
  node->sample_init_min_pop_ =
      GetIntParam(node->params, SketchParamKey::SampleInitPopulation::min_population);

  if (init_search_callbacks) {
    PrintTitle("Call init-search callbacks", verbose);
    // Candidates:
    // - auto_scheduler.PreloadMeasuredStates: Load already measured states to
    //   `measured_states_set_`, `measured_states_vector_` and `measured_states_throughputs_`.
    // - auto_scheduler.PreloadCustomSketchRule: Add user custom sketch rules to `sketch_rules`,
    //   these rules will be processed prior to the default rules.
    node->RunCallbacks(init_search_callbacks.value());
  }

  // NOTE: There are strong dependencies among the rules below,
  // so the order in which they are pushed into the vector must be chosen carefully.
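  // Sketch generation tries these rules in order for every stage; a rule that returns
  // ConditionKind::kApplyAndSkipRest stops the later rules from being applied to that stage
  // (see GenerateSketches below).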
  if (IsCPUTask(node->search_task)) {
    // Sketch Generation Rules
    node->sketch_rules.push_back(&rule_always_inline);
    node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor);
    node->sketch_rules.push_back(&rule_add_rfactor);
    node->sketch_rules.push_back(&rule_add_cache_write_stage);
    node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
    node->sketch_rules.push_back(&rule_multi_level_tiling);
    node->sketch_rules.push_back(&rule_skip_stage);

    // Initial Population Generation Rules
    node->init_rules.push_back(&init_fill_tile_size);
    node->init_rules.push_back(&init_change_compute_location);
    node->init_rules.push_back(&init_parallel);
    node->init_rules.push_back(&init_unroll);
    node->init_rules.push_back(&init_vectorization);

    // Mutation Rules for Evolutionary Search
    node->mutation_rules.push_back(std::make_shared<MutateTileSize>(0.90));
    node->mutation_rules.push_back(std::make_shared<MutateAutoUnroll>(0.04));
    node->mutation_rules.push_back(std::make_shared<MutateComputeLocation>(0.05));
    node->mutation_rules.push_back(std::make_shared<MutateParallel>(0.01));
  } else if (IsGPUTask(node->search_task)) {
    // Sketch Generation Rules
    if (node->search_task->target->GetAttr<String>("device", "") == "mali") {
      node->sketch_rules.push_back(&rule_always_inline);
      node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor);
      node->sketch_rules.push_back(&rule_add_rfactor);
      node->sketch_rules.push_back(&rule_add_cache_write_stage);
      node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
      node->sketch_rules.push_back(&rule_multi_level_tiling);
      node->sketch_rules.push_back(&rule_skip_stage);
    } else {
      node->sketch_rules.push_back(&rule_add_cache_read_stage);
      node->sketch_rules.push_back(&rule_special_compute_location_gpu);
      node->sketch_rules.push_back(&rule_always_inline);
      node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor);
      node->sketch_rules.push_back(&rule_cross_thread_reduction);
      node->sketch_rules.push_back(&rule_add_cache_write_stage);
      node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
      node->sketch_rules.push_back(&rule_multi_level_tiling);
      node->sketch_rules.push_back(&rule_skip_stage);
    }

    // Initial Population Generation Rules
    node->init_rules.push_back(&init_fill_tile_size);
    node->init_rules.push_back(&init_thread_bind);
    node->init_rules.push_back(&init_unroll);

    if (node->search_task->target->GetAttr<String>("device", "") == "mali") {
      node->init_rules.push_back(&init_vectorization);
    }

    // Mutation Rules for Evolutionary Search
    node->mutation_rules.push_back(std::make_shared<MutateTileSize>(0.90));
    node->mutation_rules.push_back(std::make_shared<MutateAutoUnroll>(0.10));
  } else {
    LOG(FATAL) << "No default sketch rules for target: " << node->search_task->target;
  }

  data_ = std::move(node);
}

State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure_per_iter,
                               ProgramMeasurer measurer) {
  num_measure_per_iter_ = num_measure_per_iter;

  if (n_trials <= 1) {
    // No measurement is allowed
    const Array<State>& best_states = SearchOneRound(0);
    ICHECK_GT(best_states.size(), 0);
    return best_states[0];
  } else {
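    // An `eps_greedy` fraction of each measurement batch is reserved for randomly sampled
    // states, so the search keeps exploring outside the states preferred by the cost model.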
    int num_random =
        static_cast<int>(GetDoubleParam(params, SketchParamKey::eps_greedy) * num_measure_per_iter);
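    // A negative `early_stopping` effectively disables early stopping by replacing it with a
    // very large sentinel (half of INT_MAX).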
    early_stopping = early_stopping < 0 ? std::numeric_limits<int>::max() >> 1 : early_stopping;
    measurer->Reset();

    int ct = 0;
    int empty_retry_count = GetIntParam(params, SketchParamKey::empty_retry_count);
    Array<State> best_states, random_states;
    Array<MeasureInput> inputs;
    Array<MeasureResult> results;
    while (ct < n_trials) {
      if (!inputs.empty()) {
        auto t_begin = std::chrono::high_resolution_clock::now();

        // Retrain the cost model before the next search round
        PrintTitle("Train cost model", verbose);
        program_cost_model->Update(inputs, results);

        PrintTimeElapsed(t_begin, "training", verbose);
      }

      // Search one round to get promising states
      PrintTitle("Search", verbose);
      best_states = SearchOneRound(num_random * 3, &random_states);

      // Infer bound. This is necessary for computing the correct ToStr() for the redundancy check.
      best_states = search_task->compute_dag.InferBound(best_states);
      random_states = search_task->compute_dag.InferBound(random_states);

      // Pick `num_measure_per_iter` states to measure, checking hashes to skip already measured
      // states. Also pick some random states for eps-greedy exploration.
      inputs = PickStatesWithEpsGreedy(best_states, random_states, n_trials - ct);

      // Currently it is hard to detect whether the whole search space has been traversed,
      // so stop if no extra valid states are found in several retries.
      if (inputs.empty()) {
        if (empty_retry_count-- > 0) {
          continue;
        } else {
          StdCout(verbose) << "It seems all candidates in the search space have been measured."
                           << std::endl;
          break;
        }
      } else {
        // Reset the retry count
        empty_retry_count = GetIntParam(params, SketchParamKey::empty_retry_count);
      }

      // Measure candidate states
      PrintTitle("Measure", verbose);
      results = measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs);
      ct += inputs.size();

      // Check if we have reached the early stopping condition
      if (ct - measurer->best_ct[search_task->workload_key] > early_stopping &&
          measurer->has_valid.count(search_task->workload_key)) {
        StdCout(verbose) << "Stop early since no performance improvement in the last "
                         << early_stopping << " measurement trials.\n";
        break;
      }

      // Update the throughputs of measured states. These states will join the EvolutionarySearch
      // in later search rounds.
      for (const auto& res : results) {
        measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs));
      }
    }
    PrintTitle("Done", verbose);

    return measurer->best_state[search_task->workload_key];
  }
}

std::pair<Array<MeasureInput>, Array<MeasureResult>> SketchPolicyNode::ContinueSearchOneRound(
    int num_measure, ProgramMeasurer measurer) {
  num_measure_per_iter_ = num_measure;

  Array<State> best_states, random_states;
  Array<MeasureInput> inputs;
  Array<MeasureResult> results;
  int num_random = static_cast<int>(GetDoubleParam(params, "eps_greedy") * num_measure);

  // Search one round to get promising states
  PrintTitle("Search", verbose);
  best_states = SearchOneRound(num_random * 3, &random_states);

  // Infer bound. This is necessary for computing the correct ToStr() for the redundancy check.
  best_states = search_task->compute_dag.InferBound(best_states);
  random_states = search_task->compute_dag.InferBound(random_states);

  // Pick `num_measure_per_iter` states to measure, checking hashes to skip already measured
  // states. Also pick some random states for eps-greedy exploration.
  inputs = PickStatesWithEpsGreedy(best_states, random_states, num_measure);

  // Measure candidate states
  PrintTitle("Measure", verbose);
  results = measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs);

  // Update the throughputs of measured states. These states will join the EvolutionarySearch
  // in later search rounds.
  for (const auto& res : results) {
    measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs));
  }

  auto t_begin = std::chrono::high_resolution_clock::now();

  // Update the cost model
  PrintTitle("Train cost model", verbose);
  program_cost_model->Update(inputs, results);

  PrintTimeElapsed(t_begin, "training", verbose);

  return std::make_pair(std::move(inputs), std::move(results));
}

Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State>* random_states) {
  // Get parameters
  int population = GetIntParam(params, SketchParamKey::EvolutionarySearch::population);
  int num_use_measured = std::min(
      static_cast<int>(measured_states_vector_.size()),
      static_cast<int>(
          GetDoubleParam(params, SketchParamKey::SampleInitPopulation::use_measured_ratio) *
          population));

  // 1. Generate sketches
  if (sketch_cache_.empty()) {
    sketch_cache_ = GenerateSketches();
  }
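  // Sketches are generated only once and cached; later search rounds reuse `sketch_cache_`.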

  // 2. Sample the init population
  Array<State> init_population = SampleInitPopulation(sketch_cache_);

  // 3. Perform evolutionary search.
  // Also insert already measured good states into the initial population
  std::vector<int> indices = Argsort(measured_states_throughputs_);
  for (int i = 0; i < num_use_measured; i++) {
    init_population.push_back(measured_states_vector_[indices[i]]);
  }
  // Sample some random states for eps-greedy
  if (num_random_states > 0 && random_states != nullptr) {
    *random_states = RandomSampleStates(init_population, &rand_gen, num_random_states);
  }
  return EvolutionarySearch(init_population, num_measure_per_iter_ * 2);
}

Array<State> SketchPolicyNode::GenerateSketches() {
  const State& init_state = search_task->compute_dag->init_state;

  // Two ping-pong buffers to avoid copies
  Array<State> states_buf1{init_state}, states_buf2;
  Array<State>* pnow = &states_buf1;
  Array<State>* pnext = &states_buf2;

  // A map from a state to its current working position (stage_id)
  std::unordered_map<State, int, ObjectHash, ObjectEqual> cur_stage_id_map;
  cur_stage_id_map[init_state] = static_cast<int>(init_state->stages.size()) - 1;
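  // Derivation works backwards over the stages: each state starts at the last stage, and a
  // negative working stage index marks a completed sketch.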

  // Derivation rule based enumeration
  Array<State> out_states;
  while (!pnow->empty()) {
    pnext->clear();
    for (const State& state : *pnow) {
      int stage_id = cur_stage_id_map[state];

      // Reached the terminal stage
      if (stage_id < 0) {
        out_states.push_back(state);
        continue;
      }

      // Try all derivation rules
      for (const auto& rule : sketch_rules) {
        auto cond = rule->MeetCondition(*this, state, stage_id);
        if (cond != SketchGenerationRule::ConditionKind::kSkip) {
          for (const auto& pair : rule->Apply(*this, state, stage_id)) {
            cur_stage_id_map[pair.first] = pair.second;
            pnext->push_back(pair.first);
          }
          // Skip the remaining rules
          if (cond == SketchGenerationRule::ConditionKind::kApplyAndSkipRest) {
            break;
          }
        }
      }
    }
    std::swap(pnow, pnext);
  }

  // Hack for rfactor: Replace the split factor for rfactor with an undefined Expr(),
  // so that we can later sample a random value for it.
  // Why don't we use an undefined Expr() when doing the split for rfactor in the first place?
  // Because during ApplySteps, an rfactor with an undefined Expr() will crash TVM,
  // so an rfactor with an undefined Expr() would conflict with cache_write, cache_read and
  // rfactor in other stages.
  for (size_t i = 0; i < out_states.size(); ++i) {
    auto state = out_states[i];
    auto pstate = state.CopyOnWrite();
    for (size_t step_id = 0; step_id < pstate->transform_steps.size(); ++step_id) {
      if (pstate->transform_steps[step_id]->IsInstance<RfactorStepNode>()) {
        ICHECK_GE(step_id, 1);
        int split_step_id = static_cast<int>(step_id - 1);
        auto step = pstate->transform_steps[split_step_id].as<SplitStepNode>();
        ICHECK(step != nullptr);
        pstate->transform_steps.Set(
            split_step_id, SplitStep(step->stage_id, step->iter_id, step->extent, {NullOpt},
                                     step->inner_to_outer));
      }
    }
    out_states.Set(i, std::move(state));
  }

  StdCout(verbose) << "Generate Sketches\t\t#s: " << out_states.size() << std::endl;
  return out_states;
}

Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches) {
  // Use the population size as the degree of parallelism for sampling
  int population = GetIntParam(params, SketchParamKey::EvolutionarySearch::population);

  auto tic_begin = std::chrono::high_resolution_clock::now();

  int fail_ct = 0;
  Array<State> out_states;
  std::vector<std::mt19937> rand_gens;
  rand_gens.reserve(population);
  for (int i = 0; i < population; i++) {
    rand_gens.push_back(std::mt19937(rand_gen()));
  }

  std::unordered_set<std::string> explored_state_strs;
  size_t iter = 1;
  size_t unchange_cnt = 0;
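  // `unchange_cnt` counts consecutive sampling batches that added no new valid state; after
  // five such batches the target population size is halved (see below).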
  while (static_cast<int>(out_states.size()) < sample_init_min_pop_) {
    std::vector<State> temp_states(population);

    // Sample a batch of states randomly
    support::parallel_for(0, population, [this, &temp_states, &sketches, &rand_gens](int index) {
      // Randomly choose a sketch
      State tmp_s = sketches[(rand_gens[index])() % sketches.size()];
      // Apply random annotation rules one by one
      bool valid = true;
      for (const auto& rule : init_rules) {
        if (rule->Apply(this, &tmp_s, &rand_gens[index]) ==
            PopulationGenerationRule::ResultKind::kInvalid) {
          valid = false;
          break;
        }
      }
      if (valid) {
        temp_states[index] = std::move(tmp_s);
      }
    });

    // Filter out the states that failed to apply the initial rules
    Array<State> cand_states;
    for (auto tmp_s : temp_states) {
      if (tmp_s.defined()) {
        cand_states.push_back(std::move(tmp_s));
      } else {
        fail_ct++;
      }
    }

    unchange_cnt++;
    if (!cand_states.empty()) {
      // Run the cost model to filter out states whose features could not be extracted.
      // This may happen due to illegal schedules or schedules that use too much memory on the GPU.
      std::vector<float> pop_scores;
      pop_scores.reserve(cand_states.size());
      cand_states = search_task->compute_dag.InferBound(cand_states);
      PruneInvalidState(search_task, &cand_states);
      program_cost_model->Predict(search_task, cand_states, &pop_scores);

      for (size_t i = 0; i < cand_states.size(); i++) {
        const auto state_str = cand_states[i].ToStr();
        if (pop_scores[i] > -1e10 && explored_state_strs.count(state_str) == 0) {
          explored_state_strs.insert(state_str);
          out_states.push_back(std::move(cand_states[i]));
          unchange_cnt = 0;  // Reset the counter once we find a valid state
        } else {
          fail_ct++;
        }
      }
    }

    if (iter % 5 == 0) {
      double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
                            std::chrono::high_resolution_clock::now() - tic_begin)
                            .count();
      StdCout(verbose) << "Sample Iter: " << iter << std::fixed << std::setprecision(4)
                       << "\t#Pop: " << out_states.size() << "\t#Target: " << sample_init_min_pop_
                       << "\tfail_ct: " << fail_ct << "\tTime elapsed: " << std::fixed
                       << std::setprecision(2) << duration << std::endl;
    }

    if (unchange_cnt == 5) {
      // Reduce the target size to avoid spending too long in this phase if no valid state
      // was found in the past iterations
      if (sample_init_min_pop_ > 1) {
        sample_init_min_pop_ /= 2;
        StdCout(verbose) << "#Target has been reduced to " << sample_init_min_pop_
                         << " due to too many failures or duplications" << std::endl;
      }
      unchange_cnt = 0;
    }
    iter++;
  }

  double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
                        std::chrono::high_resolution_clock::now() - tic_begin)
                        .count();
  StdCout(verbose) << "Sample Initial Population\t#s: " << out_states.size()
                   << "\tfail_ct: " << fail_ct << "\tTime elapsed: " << std::fixed
                   << std::setprecision(2) << duration << std::endl;
  return out_states;
}

Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_population,
                                                  int out_size) {
  Array<State> best_states;
  auto tic_begin = std::chrono::high_resolution_clock::now();

  size_t population = GetIntParam(params, SketchParamKey::EvolutionarySearch::population);
  double mutation_prob = GetDoubleParam(params, SketchParamKey::EvolutionarySearch::mutation_prob);
  int num_iters = GetIntParam(params, SketchParamKey::EvolutionarySearch::num_iters);

  bool is_cost_model_reasonable = !program_cost_model->IsInstance<RandomModelNode>();
  if (!is_cost_model_reasonable && num_iters > 2) {
    num_iters = 2;
    StdCout(verbose) << "GA iteration number has been adjusted to " << num_iters
                     << " due to random cost model" << std::endl;
  }

  // Two ping-pong buffers to avoid copies.
  Array<State> states_buf1{init_population}, states_buf2;
  states_buf1.reserve(population);
  states_buf2.reserve(population);
  Array<State>* pnow = &states_buf1;
  Array<State>* pnext = &states_buf2;

  // A heap to keep the best states during evolution
  using StateHeapItem = std::pair<State, float>;
  auto cmp = [](const StateHeapItem& left, const StateHeapItem& right) {
    return left.second > right.second;
  };
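  // This comparator makes std::push_heap / std::pop_heap maintain a min-heap on the score, so
  // heap.front() is always the lowest-scored state kept and is the one replaced when a better
  // candidate shows up.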
  std::vector<StateHeapItem> heap;
  std::unordered_set<std::string> in_heap(measured_states_set_);
  heap.reserve(out_size);

  // auxiliary global variables
  std::vector<float> pop_scores;
  std::vector<double> pop_selection_probs;
  float max_score = -1e-10f;
  pop_scores.reserve(population);
  pop_selection_probs.reserve(population);
  std::uniform_real_distribution<> dis(0.0, 1.0);

  // mutation rules
  int mutation_success_ct, mutation_fail_ct;
  mutation_success_ct = mutation_fail_ct = 0;
  std::vector<float> rule_weights;
  std::vector<double> rule_selection_probs;
  for (const auto& rule : mutation_rules) {
    rule_weights.push_back(rule->weight);
  }
  ComputePrefixSumProb(rule_weights, &rule_selection_probs);
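  // RandomChoose over the prefix-sum probabilities performs roulette-wheel selection: a mutation
  // rule (and, below, a parent state) is picked with probability roughly proportional to its
  // weight (resp. score).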

  // Genetic Algorithm
  for (int k = 0; k < num_iters + 1; ++k) {
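    // The loop runs `num_iters` mutation rounds plus one final pass that only scores the last
    // population and updates the heap (it breaks before the mutation step).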
    // Maintain the heap
    *pnow = search_task->compute_dag.InferBound(*pnow);
    PruneInvalidState(search_task, pnow);
    program_cost_model->Predict(search_task, *pnow, &pop_scores);

    for (size_t i = 0; i < pnow->size(); ++i) {
      const State& state = (*pnow)[i];
      std::string state_str = state.ToStr();

      if (in_heap.count(state_str) == 0) {
        if (static_cast<int>(heap.size()) < out_size) {
          heap.emplace_back((*pnow)[i], pop_scores[i]);
          std::push_heap(heap.begin(), heap.end(), cmp);
          in_heap.insert(state_str);
        } else if (pop_scores[i] > heap.front().second) {
          std::string old_state_str = heap.front().first.ToStr();
          in_heap.erase(old_state_str);
          in_heap.insert(state_str);

          std::pop_heap(heap.begin(), heap.end(), cmp);
          heap.back() = StateHeapItem(state, pop_scores[i]);
          std::push_heap(heap.begin(), heap.end(), cmp);
        }
        if (pop_scores[i] > max_score) {
          max_score = pop_scores[i];
        }
      }
    }

    // Print statistical information
    if (k % 5 == 0 || k == num_iters) {
      StdCout(verbose) << "GA Iter: " << k;
      if (!heap.empty()) {
        StdCout(verbose) << std::fixed << std::setprecision(4) << "\tMax score: " << max_score
                         << std::fixed << std::setprecision(4)
                         << "\tMin score: " << heap.front().second;
      } else {
        StdCout(verbose) << "\tMax score: N/A\tMin score: N/A";
      }
      StdCout(verbose) << "\t#Pop: " << heap.size() << "\t#M+: " << mutation_success_ct / (k + 1)
                       << "\t#M-: " << mutation_fail_ct / (k + 1) << std::endl;
    }
    if (k == num_iters) {
      break;
    }

    // Compute selection probability
    ComputePrefixSumProb(pop_scores, &pop_selection_probs);

    // TODO(merrymercy, comaniac): add crossover.

    // Do mutation
    while (pnext->size() < population) {
      State tmp_s = (*pnow)[RandomChoose(pop_selection_probs, &rand_gen)];

      if (dis(rand_gen) < mutation_prob) {
        const auto& rule = mutation_rules[RandomChoose(rule_selection_probs, &rand_gen)];
        if (rule->Apply(this, &tmp_s, &rand_gen) == PopulationGenerationRule::ResultKind::kValid) {
          pnext->push_back(std::move(tmp_s));
          mutation_success_ct++;
        } else {
          mutation_fail_ct++;
        }
      } else {
        pnext->push_back(std::move(tmp_s));
      }
    }

    std::swap(pnext, pnow);
    pnext->clear();
  }

  // Copy the best states from the heap to the output
  std::sort(heap.begin(), heap.end(), cmp);
  for (auto& item : heap) {
    best_states.push_back(std::move(item.first));
  }

  double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
                        std::chrono::high_resolution_clock::now() - tic_begin)
                        .count();
  StdCout(verbose) << "EvolutionarySearch\t\t#s: " << best_states.size()
                   << "\tTime elapsed: " << std::fixed << std::setprecision(2) << duration
                   << std::endl;
  return best_states;
}

Array<MeasureInput> SketchPolicyNode::PickStatesWithEpsGreedy(const Array<State>& best_states,
                                                              const Array<State>& random_states,
                                                              int remaining_n_trials) {
  int num_random =
      static_cast<int>(GetDoubleParam(params, SketchParamKey::eps_greedy) * num_measure_per_iter_);
  int num_good = num_measure_per_iter_ - num_random;
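  // Epsilon-greedy selection: the first `num_good` picks prefer the cost model's best states,
  // the remaining picks prefer randomly sampled states, and either side falls back to the other
  // when its list runs out.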

  Array<MeasureInput> inputs;
  size_t offset_best = 0, offset_random = 0;

  while (static_cast<int>(inputs.size()) < std::min(num_measure_per_iter_, remaining_n_trials)) {
    State state;

    bool has_best = offset_best < best_states.size();
    bool has_random = offset_random < random_states.size();

    if (static_cast<int>(inputs.size()) < num_good) {
      // prefer best states
      if (has_best) {
        state = best_states[offset_best++];
      } else if (has_random) {
        state = random_states[offset_random++];
      } else {
        break;
      }
    } else {
      // prefer random states
      if (has_random) {
        state = random_states[offset_random++];
      } else if (has_best) {
        state = best_states[offset_best++];
      } else {
        break;
      }
    }

    // Check if it has already been measured
    std::string state_str = state.ToStr();
    if (!measured_states_set_.count(state_str)) {
      measured_states_set_.insert(std::move(state_str));
      measured_states_vector_.push_back(state);
      inputs.push_back(MeasureInput(search_task, state));
    }
  }

  return inputs;
}

/********** PreloadCustomSketchRule **********/
TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode);

PreloadCustomSketchRule::PreloadCustomSketchRule(PackedFunc meet_condition_func,
                                                 PackedFunc apply_func, String rule_name) {
  auto node = make_object<PreloadCustomSketchRuleNode>();
  node->meet_condition_func = std::move(meet_condition_func);
  node->apply_func = std::move(apply_func);
  node->rule_name = std::move(rule_name);
  data_ = std::move(node);
}

void PreloadCustomSketchRuleNode::Callback(SearchPolicyNode* policy) {
  CHECK(policy->IsInstance<SketchPolicyNode>());
  auto sketch_policy = dynamic_cast<SketchPolicyNode*>(policy);
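  // When run as an init-search callback (see the constructor above), this executes before the
  // default sketch rules are pushed, so custom rules end up ahead of the default ones in
  // `sketch_rules`.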
  sketch_policy->sketch_rules.push_back(
      new RuleCustomSketch(meet_condition_func, apply_func, rule_name));
  StdCout(policy->verbose) << "Custom sketch rule \"" << rule_name << "\" added." << std::endl;
}

TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicy")
    .set_body_typed([](SearchTask task, CostModel program_cost_model, Map<String, ObjectRef> params,
                       int seed, int verbose,
                       Optional<Array<SearchCallback>> init_search_callbacks) {
      return SketchPolicy(task, program_cost_model, params, seed, verbose, init_search_callbacks);
    });

TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicyGenerateSketches")
    .set_body_typed([](SketchPolicy policy) { return policy->GenerateSketches(); });

TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicySampleInitialPopulation")
    .set_body_typed([](SketchPolicy policy) {
      const Array<State>& sketches = policy->GenerateSketches();

      Array<State> init_population = policy->SampleInitPopulation(sketches);
      return init_population;
    });

TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicyEvolutionarySearch")
    .set_body_typed([](SketchPolicy policy, Array<State> init_population, int out_size) {
      Array<State> states = policy->EvolutionarySearch(init_population, out_size);
      return states;
    });

TVM_REGISTER_GLOBAL("auto_scheduler.PrintTitle").set_body_typed([](std::string title) {
  PrintTitle(title, 1);
});

TVM_REGISTER_GLOBAL("auto_scheduler.PreloadCustomSketchRule")
    .set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name) {
      return PreloadCustomSketchRule(meet_condition_func, apply_func, rule_name);
    });

}  // namespace auto_scheduler
}  // namespace tvm