1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file auto_scheduler/search_policy/sketch_policy.cc
 * \brief The search policy that searches in a hierarchical search space defined by sketches.
 * The policy randomly samples programs from the space defined by sketches
 * and uses evolutionary search to fine-tune them.
 */
26 | |
27 | #include "sketch_policy.h" |
28 | |
29 | #include <tvm/runtime/registry.h> |
30 | #include <tvm/support/parallel_for.h> |
31 | |
32 | #include <algorithm> |
33 | #include <iomanip> |
34 | #include <limits> |
35 | #include <memory> |
36 | #include <queue> |
37 | #include <set> |
38 | #include <string> |
39 | #include <unordered_map> |
40 | #include <unordered_set> |
41 | #include <utility> |
42 | #include <vector> |
43 | |
44 | #include "sketch_policy_rules.h" |
45 | |
46 | namespace tvm { |
47 | namespace auto_scheduler { |
48 | |
/********** Sketch generation rules **********/
// File-scope rule instances. SketchPolicy's constructor pushes their addresses into
// `sketch_rules`, so every SketchPolicy built in this process shares the same objects.
// NOTE(review): sharing is only safe if these rules keep no per-policy mutable state — confirm.
static RuleSkipStage rule_skip_stage;
static RuleAlwaysInline rule_always_inline;
static RuleMultiLevelTiling rule_multi_level_tiling;
static RuleMultiLevelTilingWithFusion rule_multi_level_tiling_with_fusion;
static RuleAddCacheRead rule_add_cache_read_stage;
static RuleAddCacheWrite rule_add_cache_write_stage;
static RuleAddRfactor rule_add_rfactor;
static RuleCrossThreadReduction rule_cross_thread_reduction;
static RuleSimplifyComputeWithConstTensor rule_simplify_compute_with_const_tensor;
static RuleSpecialComputeLocationGPU rule_special_compute_location_gpu;

/********** Init population rules **********/
// File-scope instances whose addresses are pushed into `init_rules` by SketchPolicy's constructor.
static InitFillTileSize init_fill_tile_size;
static InitChangeComputeLocation init_change_compute_location;
static InitParallel init_parallel;
static InitUnroll init_unroll;
static InitVectorization init_vectorization;
static InitThreadBind init_thread_bind;

/********** Sketch policy **********/
// Register SketchPolicyNode with TVM's object/reflection system.
TVM_REGISTER_NODE_TYPE(SketchPolicyNode);
71 | |
72 | SketchPolicy::SketchPolicy(SearchTask task, CostModel program_cost_model, |
73 | Map<String, ObjectRef> params, int seed, int verbose, |
74 | Optional<Array<SearchCallback>> init_search_callbacks) { |
75 | auto node = make_object<SketchPolicyNode>(); |
76 | node->search_task = std::move(task); |
77 | node->program_cost_model = std::move(program_cost_model); |
78 | node->rand_gen = std::mt19937(seed); |
79 | node->params = std::move(params); |
80 | node->verbose = verbose; |
81 | node->sample_init_min_pop_ = |
82 | GetIntParam(node->params, SketchParamKey::SampleInitPopulation::min_population); |
83 | |
84 | if (init_search_callbacks) { |
85 | PrintTitle("Call init-search callbacks" , verbose); |
86 | // Candidates: |
87 | // - auto_scheduler.PreloadMeasuredStates: Load already measured states to |
88 | // `measured_states_set_`, `measured_states_vector_` and `measured_states_throughputs_`. |
89 | // - auto_scheduler.PreloadCustomSketchRule: Add user custom sketch rules to `sketch_rules`, |
90 | // these rules will be processed prior to the default rules. |
91 | node->RunCallbacks(init_search_callbacks.value()); |
92 | } |
93 | |
94 | // NOTE: There are strong dependency among the rules below, |
95 | // so the order to push them into the vector should be considered carefully. |
96 | if (IsCPUTask(node->search_task)) { |
97 | // Sketch Generation Rules |
98 | node->sketch_rules.push_back(&rule_always_inline); |
99 | node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor); |
100 | node->sketch_rules.push_back(&rule_add_rfactor); |
101 | node->sketch_rules.push_back(&rule_add_cache_write_stage); |
102 | node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion); |
103 | node->sketch_rules.push_back(&rule_multi_level_tiling); |
104 | node->sketch_rules.push_back(&rule_skip_stage); |
105 | |
106 | // Initial Population Generation Rules |
107 | node->init_rules.push_back(&init_fill_tile_size); |
108 | node->init_rules.push_back(&init_change_compute_location); |
109 | node->init_rules.push_back(&init_parallel); |
110 | node->init_rules.push_back(&init_unroll); |
111 | node->init_rules.push_back(&init_vectorization); |
112 | |
113 | // Mutation Rules for Evolutionary Search |
114 | node->mutation_rules.push_back(std::make_shared<MutateTileSize>(0.90)); |
115 | node->mutation_rules.push_back(std::make_shared<MutateAutoUnroll>(0.04)); |
116 | node->mutation_rules.push_back(std::make_shared<MutateComputeLocation>(0.05)); |
117 | node->mutation_rules.push_back(std::make_shared<MutateParallel>(0.01)); |
118 | } else if (IsGPUTask(node->search_task)) { |
119 | // Sketch Generation Rules |
120 | if (node->search_task->target->GetAttr<String>("device" , "" ) == "mali" ) { |
121 | node->sketch_rules.push_back(&rule_always_inline); |
122 | node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor); |
123 | node->sketch_rules.push_back(&rule_add_rfactor); |
124 | node->sketch_rules.push_back(&rule_add_cache_write_stage); |
125 | node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion); |
126 | node->sketch_rules.push_back(&rule_multi_level_tiling); |
127 | node->sketch_rules.push_back(&rule_skip_stage); |
128 | } else { |
129 | node->sketch_rules.push_back(&rule_add_cache_read_stage); |
130 | node->sketch_rules.push_back(&rule_special_compute_location_gpu); |
131 | node->sketch_rules.push_back(&rule_always_inline); |
132 | node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor); |
133 | node->sketch_rules.push_back(&rule_cross_thread_reduction); |
134 | node->sketch_rules.push_back(&rule_add_cache_write_stage); |
135 | node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion); |
136 | node->sketch_rules.push_back(&rule_multi_level_tiling); |
137 | node->sketch_rules.push_back(&rule_skip_stage); |
138 | } |
139 | |
140 | // Initial Population Generation Rules |
141 | node->init_rules.push_back(&init_fill_tile_size); |
142 | node->init_rules.push_back(&init_thread_bind); |
143 | node->init_rules.push_back(&init_unroll); |
144 | |
145 | if (node->search_task->target->GetAttr<String>("device" , "" ) == "mali" ) { |
146 | node->init_rules.push_back(&init_vectorization); |
147 | } |
148 | |
149 | // Mutation Rules for Evolutionary Search |
150 | node->mutation_rules.push_back(std::make_shared<MutateTileSize>(0.90)); |
151 | node->mutation_rules.push_back(std::make_shared<MutateAutoUnroll>(0.10)); |
152 | } else { |
153 | LOG(FATAL) << "No default sketch rules for target: " << node->search_task->target; |
154 | } |
155 | |
156 | data_ = std::move(node); |
157 | } |
158 | |
159 | State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure_per_iter, |
160 | ProgramMeasurer measurer) { |
161 | num_measure_per_iter_ = num_measure_per_iter; |
162 | |
163 | if (n_trials <= 1) { |
164 | // No measurement is allowed |
165 | const Array<State>& best_states = SearchOneRound(0); |
166 | ICHECK_GT(best_states.size(), 0); |
167 | return best_states[0]; |
168 | } else { |
169 | int num_random = |
170 | static_cast<int>(GetDoubleParam(params, SketchParamKey::eps_greedy) * num_measure_per_iter); |
171 | early_stopping = early_stopping < 0 ? std::numeric_limits<int>::max() >> 1 : early_stopping; |
172 | measurer->Reset(); |
173 | |
174 | int ct = 0; |
175 | int empty_retry_count = GetIntParam(params, SketchParamKey::empty_retry_count); |
176 | Array<State> best_states, random_states; |
177 | Array<MeasureInput> inputs; |
178 | Array<MeasureResult> results; |
179 | while (ct < n_trials) { |
180 | if (!inputs.empty()) { |
181 | auto t_begin = std::chrono::high_resolution_clock::now(); |
182 | |
183 | // Retrain the cost model before the next search round |
184 | PrintTitle("Train cost model" , verbose); |
185 | program_cost_model->Update(inputs, results); |
186 | |
187 | PrintTimeElapsed(t_begin, "training" , verbose); |
188 | } |
189 | |
190 | // Search one round to get promising states |
191 | PrintTitle("Search" , verbose); |
192 | best_states = SearchOneRound(num_random * 3, &random_states); |
193 | |
194 | // Infer bound. This is necessary for computing the correct ToStr() for redundancy check |
195 | best_states = search_task->compute_dag.InferBound(best_states); |
196 | random_states = search_task->compute_dag.InferBound(random_states); |
197 | |
198 | // Pick `num_measure_per_iter` states to measure, check hash to remove already measured state |
199 | // Also pick some random states to do eps-greedy |
200 | inputs = PickStatesWithEpsGreedy(best_states, random_states, n_trials - ct); |
201 | |
202 | // Currently it's hard to detect if all of the search space has been traversed |
203 | // Stop if no extra valid states found in several retries |
204 | if (inputs.empty()) { |
205 | if (empty_retry_count-- > 0) { |
206 | continue; |
207 | } else { |
208 | StdCout(verbose) << "It seems all candidates in the search space have been measured." |
209 | << std::endl; |
210 | break; |
211 | } |
212 | } else { |
213 | // Reset the retry count |
214 | empty_retry_count = GetIntParam(params, SketchParamKey::empty_retry_count); |
215 | } |
216 | |
217 | // Measure candidate states |
218 | PrintTitle("Measure" , verbose); |
219 | results = measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs); |
220 | ct += inputs.size(); |
221 | |
222 | // Check if reach the early stopping condition |
223 | if (ct - measurer->best_ct[search_task->workload_key] > early_stopping && |
224 | measurer->has_valid.count(search_task->workload_key)) { |
225 | StdCout(verbose) << "Stop early since no performance improvement in the last " |
226 | << early_stopping << " measurements trials.\n" ; |
227 | break; |
228 | } |
229 | |
230 | // Update measured states throughputs. These states will join the EvolutionarySearch in later |
231 | // search rounds. |
232 | for (const auto& res : results) { |
233 | measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs)); |
234 | } |
235 | } |
236 | PrintTitle("Done" , verbose); |
237 | |
238 | return measurer->best_state[search_task->workload_key]; |
239 | } |
240 | } |
241 | |
242 | std::pair<Array<MeasureInput>, Array<MeasureResult>> SketchPolicyNode::ContinueSearchOneRound( |
243 | int num_measure, ProgramMeasurer measurer) { |
244 | num_measure_per_iter_ = num_measure; |
245 | |
246 | Array<State> best_states, random_states; |
247 | Array<MeasureInput> inputs; |
248 | Array<MeasureResult> results; |
249 | int num_random = static_cast<int>(GetDoubleParam(params, "eps_greedy" ) * num_measure); |
250 | |
251 | // Search one round to get promising states |
252 | PrintTitle("Search" , verbose); |
253 | best_states = SearchOneRound(num_random * 3, &random_states); |
254 | |
255 | // Infer bound. This is necessary for computing the correct ToStr() for redundancy check |
256 | best_states = search_task->compute_dag.InferBound(best_states); |
257 | random_states = search_task->compute_dag.InferBound(random_states); |
258 | |
259 | // Pick `num_measure_per_iter` states to measure, check hash to remove already measured state |
260 | // Also pick some random states to do eps-greedy |
261 | inputs = PickStatesWithEpsGreedy(best_states, random_states, num_measure); |
262 | |
263 | // Measure candidate states |
264 | PrintTitle("Measure" , verbose); |
265 | results = measurer->Measure(search_task, GetRef<SearchPolicy>(this), inputs); |
266 | |
267 | // Update measured states throughputs. These states will join the EvolutionarySearch in later |
268 | // search rounds. |
269 | for (const auto& res : results) { |
270 | measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs)); |
271 | } |
272 | |
273 | auto t_begin = std::chrono::high_resolution_clock::now(); |
274 | |
275 | // Update the cost model |
276 | PrintTitle("Train cost model" , verbose); |
277 | program_cost_model->Update(inputs, results); |
278 | |
279 | PrintTimeElapsed(t_begin, "training" , verbose); |
280 | |
281 | return std::make_pair(std::move(inputs), std::move(results)); |
282 | } |
283 | |
284 | Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State>* random_states) { |
285 | // Get parameters |
286 | int population = GetIntParam(params, SketchParamKey::EvolutionarySearch::population); |
287 | int num_use_measured = std::min( |
288 | static_cast<int>(measured_states_vector_.size()), |
289 | static_cast<int>( |
290 | GetDoubleParam(params, SketchParamKey::SampleInitPopulation::use_measured_ratio) * |
291 | population)); |
292 | |
293 | // 1. Generate sketches |
294 | if (sketch_cache_.empty()) { |
295 | sketch_cache_ = GenerateSketches(); |
296 | } |
297 | |
298 | // 2. Sample the init population |
299 | Array<State> init_population = SampleInitPopulation(sketch_cache_); |
300 | |
301 | // 3. Perform evolutionary search. |
302 | // Also insert already measured good states to the initial population |
303 | std::vector<int> indices = Argsort(measured_states_throughputs_); |
304 | for (int i = 0; i < num_use_measured; i++) { |
305 | init_population.push_back(measured_states_vector_[indices[i]]); |
306 | } |
307 | // Sample some random states for eps-greedy |
308 | if (num_random_states > 0 && random_states != nullptr) { |
309 | *random_states = RandomSampleStates(init_population, &rand_gen, num_random_states); |
310 | } |
311 | return EvolutionarySearch(init_population, num_measure_per_iter_ * 2); |
312 | } |
313 | |
314 | Array<State> SketchPolicyNode::GenerateSketches() { |
315 | const State& init_state = search_task->compute_dag->init_state; |
316 | |
317 | // Two ping pong buffers to avoid copy |
318 | Array<State> states_buf1{init_state}, states_buf2; |
319 | Array<State>* pnow = &states_buf1; |
320 | Array<State>* pnext = &states_buf2; |
321 | |
322 | // A map that maps state to its current working position (stage_id) |
323 | std::unordered_map<State, int, ObjectHash, ObjectEqual> cur_stage_id_map; |
324 | cur_stage_id_map[init_state] = static_cast<int>(init_state->stages.size()) - 1; |
325 | |
326 | // Derivation rule based enumeration |
327 | Array<State> out_states; |
328 | while (!pnow->empty()) { |
329 | pnext->clear(); |
330 | for (const State& state : *pnow) { |
331 | int stage_id = cur_stage_id_map[state]; |
332 | |
333 | // Reaches to the terminal stage |
334 | if (stage_id < 0) { |
335 | out_states.push_back(state); |
336 | continue; |
337 | } |
338 | |
339 | // Try all derivation rules |
340 | for (const auto& rule : sketch_rules) { |
341 | auto cond = rule->MeetCondition(*this, state, stage_id); |
342 | if (cond != SketchGenerationRule::ConditionKind::kSkip) { |
343 | for (const auto& pair : rule->Apply(*this, state, stage_id)) { |
344 | cur_stage_id_map[pair.first] = pair.second; |
345 | pnext->push_back(pair.first); |
346 | } |
347 | // Skip the rest rules |
348 | if (cond == SketchGenerationRule::ConditionKind::kApplyAndSkipRest) { |
349 | break; |
350 | } |
351 | } |
352 | } |
353 | } |
354 | std::swap(pnow, pnext); |
355 | } |
356 | |
357 | // Hack for rfactor: Replace the split factor for rfactor to the undefined Expr(), |
358 | // so later we can sample random value for the split factor. |
359 | // Why don't we use Expr() when doing the split for rfactor at the first time? |
360 | // Because during ApplySteps, a rfactor with undefined Expr() will crash TVM. |
361 | // So rfactor with undefined Expr() will conflict with cache_write, cache_read, rfactor |
362 | // in other stages |
363 | for (size_t i = 0; i < out_states.size(); ++i) { |
364 | auto state = out_states[i]; |
365 | auto pstate = state.CopyOnWrite(); |
366 | for (size_t step_id = 0; step_id < pstate->transform_steps.size(); ++step_id) { |
367 | if (pstate->transform_steps[step_id]->IsInstance<RfactorStepNode>()) { |
368 | ICHECK_GE(step_id, 1); |
369 | int split_step_id = static_cast<int>(step_id - 1); |
370 | auto step = pstate->transform_steps[split_step_id].as<SplitStepNode>(); |
371 | ICHECK(step != nullptr); |
372 | pstate->transform_steps.Set( |
373 | split_step_id, SplitStep(step->stage_id, step->iter_id, step->extent, {NullOpt}, |
374 | step->inner_to_outer)); |
375 | } |
376 | } |
377 | out_states.Set(i, std::move(state)); |
378 | } |
379 | |
380 | StdCout(verbose) << "Generate Sketches\t\t#s: " << out_states.size() << std::endl; |
381 | return out_states; |
382 | } |
383 | |
384 | Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches) { |
385 | // Use this population as the parallel degree to do sampling |
386 | int population = GetIntParam(params, SketchParamKey::EvolutionarySearch::population); |
387 | |
388 | auto tic_begin = std::chrono::high_resolution_clock::now(); |
389 | |
390 | int fail_ct = 0; |
391 | Array<State> out_states; |
392 | std::vector<std::mt19937> rand_gens; |
393 | rand_gens.reserve(population); |
394 | for (int i = 0; i < population; i++) { |
395 | rand_gens.push_back(std::mt19937(rand_gen())); |
396 | } |
397 | |
398 | std::unordered_set<std::string> explored_state_strs; |
399 | size_t iter = 1; |
400 | size_t unchange_cnt = 0; |
401 | while (static_cast<int>(out_states.size()) < sample_init_min_pop_) { |
402 | std::vector<State> temp_states(population); |
403 | |
404 | // Sample a batch of states randomly |
405 | support::parallel_for(0, population, [this, &temp_states, &sketches, &rand_gens](int index) { |
406 | // Randomly choose a sketch |
407 | State tmp_s = sketches[(rand_gens[index])() % sketches.size()]; |
408 | // Apply random annotation rules one by one |
409 | bool valid = true; |
410 | for (const auto& rule : init_rules) { |
411 | if (rule->Apply(this, &tmp_s, &rand_gens[index]) == |
412 | PopulationGenerationRule::ResultKind::kInvalid) { |
413 | valid = false; |
414 | break; |
415 | } |
416 | } |
417 | if (valid) { |
418 | temp_states[index] = std::move(tmp_s); |
419 | } |
420 | }); |
421 | |
422 | // Filter out the states that were failed to apply initial rules |
423 | Array<State> cand_states; |
424 | for (auto tmp_s : temp_states) { |
425 | if (tmp_s.defined()) { |
426 | cand_states.push_back(std::move(tmp_s)); |
427 | } else { |
428 | fail_ct++; |
429 | } |
430 | } |
431 | |
432 | unchange_cnt++; |
433 | if (!cand_states.empty()) { |
434 | // Run the cost model to make filter out states that failed to extract features. |
435 | // This may happen due to illegal schedules or the schedules that uses too much |
436 | // memory on GPU. |
437 | std::vector<float> pop_scores; |
438 | pop_scores.reserve(cand_states.size()); |
439 | cand_states = search_task->compute_dag.InferBound(cand_states); |
440 | PruneInvalidState(search_task, &cand_states); |
441 | program_cost_model->Predict(search_task, cand_states, &pop_scores); |
442 | |
443 | for (size_t i = 0; i < cand_states.size(); i++) { |
444 | const auto state_str = cand_states[i].ToStr(); |
445 | if (pop_scores[i] > -1e10 && explored_state_strs.count(state_str) == 0) { |
446 | explored_state_strs.insert(state_str); |
447 | out_states.push_back(std::move(cand_states[i])); |
448 | unchange_cnt = 0; // Reset the counter once we found a valid state |
449 | } else { |
450 | fail_ct++; |
451 | } |
452 | } |
453 | } |
454 | |
455 | if (iter % 5 == 0) { |
456 | double duration = std::chrono::duration_cast<std::chrono::duration<double>>( |
457 | std::chrono::high_resolution_clock::now() - tic_begin) |
458 | .count(); |
459 | StdCout(verbose) << "Sample Iter: " << iter << std::fixed << std::setprecision(4) |
460 | << "\t#Pop: " << out_states.size() << "\t#Target: " << sample_init_min_pop_ |
461 | << "\tfail_ct: " << fail_ct << "\tTime elapsed: " << std::fixed |
462 | << std::setprecision(2) << duration << std::endl; |
463 | } |
464 | |
465 | if (unchange_cnt == 5) { |
466 | // Reduce the target size to avoid too-long time in this phase if no valid state was found |
467 | // in the past iterations |
468 | if (sample_init_min_pop_ > 1) { |
469 | sample_init_min_pop_ /= 2; |
470 | StdCout(verbose) << "#Target has been reduced to " << sample_init_min_pop_ |
471 | << " due to too many failures or duplications" << std::endl; |
472 | } |
473 | unchange_cnt = 0; |
474 | } |
475 | iter++; |
476 | } |
477 | |
478 | double duration = std::chrono::duration_cast<std::chrono::duration<double>>( |
479 | std::chrono::high_resolution_clock::now() - tic_begin) |
480 | .count(); |
481 | StdCout(verbose) << "Sample Initial Population\t#s: " << out_states.size() |
482 | << "\tfail_ct: " << fail_ct << "\tTime elapsed: " << std::fixed |
483 | << std::setprecision(2) << duration << std::endl; |
484 | return out_states; |
485 | } |
486 | |
487 | Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_population, |
488 | int out_size) { |
489 | Array<State> best_states; |
490 | auto tic_begin = std::chrono::high_resolution_clock::now(); |
491 | |
492 | size_t population = GetIntParam(params, SketchParamKey::EvolutionarySearch::population); |
493 | double mutation_prob = GetDoubleParam(params, SketchParamKey::EvolutionarySearch::mutation_prob); |
494 | int num_iters = GetIntParam(params, SketchParamKey::EvolutionarySearch::num_iters); |
495 | |
496 | bool is_cost_model_reasonable = !program_cost_model->IsInstance<RandomModelNode>(); |
497 | if (!is_cost_model_reasonable && num_iters > 2) { |
498 | num_iters = 2; |
499 | StdCout(verbose) << "GA iteration number has been adjusted to " << num_iters |
500 | << " due to random cost model" << std::endl; |
501 | } |
502 | |
503 | // Two ping pong buffers to avoid copy. |
504 | Array<State> states_buf1{init_population}, states_buf2; |
505 | states_buf1.reserve(population); |
506 | states_buf2.reserve(population); |
507 | Array<State>* pnow = &states_buf1; |
508 | Array<State>* pnext = &states_buf2; |
509 | |
510 | // A heap to keep the best states during evolution |
511 | using StateHeapItem = std::pair<State, float>; |
512 | auto cmp = [](const StateHeapItem& left, const StateHeapItem& right) { |
513 | return left.second > right.second; |
514 | }; |
515 | std::vector<StateHeapItem> heap; |
516 | std::unordered_set<std::string> in_heap(measured_states_set_); |
517 | heap.reserve(out_size); |
518 | |
519 | // auxiliary global variables |
520 | std::vector<float> pop_scores; |
521 | std::vector<double> pop_selection_probs; |
522 | float max_score = -1e-10f; |
523 | pop_scores.reserve(population); |
524 | pop_selection_probs.reserve(population); |
525 | std::uniform_real_distribution<> dis(0.0, 1.0); |
526 | |
527 | // mutation rules |
528 | int mutation_success_ct, mutation_fail_ct; |
529 | mutation_success_ct = mutation_fail_ct = 0; |
530 | std::vector<float> rule_weights; |
531 | std::vector<double> rule_selection_probs; |
532 | for (const auto& rule : mutation_rules) { |
533 | rule_weights.push_back(rule->weight); |
534 | } |
535 | ComputePrefixSumProb(rule_weights, &rule_selection_probs); |
536 | |
537 | // Genetic Algorithm |
538 | for (int k = 0; k < num_iters + 1; ++k) { |
539 | // Maintain the heap |
540 | *pnow = search_task->compute_dag.InferBound(*pnow); |
541 | PruneInvalidState(search_task, pnow); |
542 | program_cost_model->Predict(search_task, *pnow, &pop_scores); |
543 | |
544 | for (size_t i = 0; i < pnow->size(); ++i) { |
545 | const State& state = (*pnow)[i]; |
546 | std::string state_str = state.ToStr(); |
547 | |
548 | if (in_heap.count(state_str) == 0) { |
549 | if (static_cast<int>(heap.size()) < out_size) { |
550 | heap.emplace_back((*pnow)[i], pop_scores[i]); |
551 | std::push_heap(heap.begin(), heap.end(), cmp); |
552 | in_heap.insert(state_str); |
553 | } else if (pop_scores[i] > heap.front().second) { |
554 | std::string old_state_str = heap.front().first.ToStr(); |
555 | in_heap.erase(old_state_str); |
556 | in_heap.insert(state_str); |
557 | |
558 | std::pop_heap(heap.begin(), heap.end(), cmp); |
559 | heap.back() = StateHeapItem(state, pop_scores[i]); |
560 | std::push_heap(heap.begin(), heap.end(), cmp); |
561 | } |
562 | if (pop_scores[i] > max_score) { |
563 | max_score = pop_scores[i]; |
564 | } |
565 | } |
566 | } |
567 | |
568 | // Print statistical information |
569 | if (k % 5 == 0 || k == num_iters) { |
570 | StdCout(verbose) << "GA Iter: " << k; |
571 | if (!heap.empty()) { |
572 | StdCout(verbose) << std::fixed << std::setprecision(4) << "\tMax score: " << max_score |
573 | << std::fixed << std::setprecision(4) |
574 | << "\tMin score: " << heap.front().second; |
575 | } else { |
576 | StdCout(verbose) << "\tMax score: N/A\tMin score: N/A" ; |
577 | } |
578 | StdCout(verbose) << "\t#Pop: " << heap.size() << "\t#M+: " << mutation_success_ct / (k + 1) |
579 | << "\t#M-: " << mutation_fail_ct / (k + 1) << std::endl; |
580 | } |
581 | if (k == num_iters) { |
582 | break; |
583 | } |
584 | |
585 | // Compute selection probability |
586 | ComputePrefixSumProb(pop_scores, &pop_selection_probs); |
587 | |
588 | // TODO(merrymercy, comaniac): add crossover. |
589 | |
590 | // Do mutation |
591 | while (pnext->size() < population) { |
592 | State tmp_s = (*pnow)[RandomChoose(pop_selection_probs, &rand_gen)]; |
593 | |
594 | if (dis(rand_gen) < mutation_prob) { |
595 | const auto& rule = mutation_rules[RandomChoose(rule_selection_probs, &rand_gen)]; |
596 | if (rule->Apply(this, &tmp_s, &rand_gen) == PopulationGenerationRule::ResultKind::kValid) { |
597 | pnext->push_back(std::move(tmp_s)); |
598 | mutation_success_ct++; |
599 | } else { |
600 | mutation_fail_ct++; |
601 | } |
602 | } else { |
603 | pnext->push_back(std::move(tmp_s)); |
604 | } |
605 | } |
606 | |
607 | std::swap(pnext, pnow); |
608 | pnext->clear(); |
609 | } |
610 | |
611 | // Copy best states in the heap to out_states |
612 | std::sort(heap.begin(), heap.end(), cmp); |
613 | for (auto& item : heap) { |
614 | best_states.push_back(std::move(item.first)); |
615 | } |
616 | |
617 | double duration = std::chrono::duration_cast<std::chrono::duration<double>>( |
618 | std::chrono::high_resolution_clock::now() - tic_begin) |
619 | .count(); |
620 | StdCout(verbose) << "EvolutionarySearch\t\t#s: " << best_states.size() |
621 | << "\tTime elapsed: " << std::fixed << std::setprecision(2) << duration |
622 | << std::endl; |
623 | return best_states; |
624 | } |
625 | |
626 | Array<MeasureInput> SketchPolicyNode::PickStatesWithEpsGreedy(const Array<State>& best_states, |
627 | const Array<State>& random_states, |
628 | int remaining_n_trials) { |
629 | int num_random = |
630 | static_cast<int>(GetDoubleParam(params, SketchParamKey::eps_greedy) * num_measure_per_iter_); |
631 | int num_good = num_measure_per_iter_ - num_random; |
632 | |
633 | Array<MeasureInput> inputs; |
634 | size_t offset_best = 0, offset_random = 0; |
635 | |
636 | while (static_cast<int>(inputs.size()) < std::min(num_measure_per_iter_, remaining_n_trials)) { |
637 | State state; |
638 | |
639 | bool has_best = offset_best < best_states.size(); |
640 | bool has_random = offset_random < random_states.size(); |
641 | |
642 | if (static_cast<int>(inputs.size()) < num_good) { |
643 | // prefer best states |
644 | if (has_best) { |
645 | state = best_states[offset_best++]; |
646 | } else if (has_random) { |
647 | state = random_states[offset_random++]; |
648 | } else { |
649 | break; |
650 | } |
651 | } else { |
652 | // prefer random states |
653 | if (has_random) { |
654 | state = random_states[offset_random++]; |
655 | } else if (has_best) { |
656 | state = best_states[offset_best++]; |
657 | } else { |
658 | break; |
659 | } |
660 | } |
661 | |
662 | // Check if it has already been measured |
663 | std::string state_str = state.ToStr(); |
664 | if (!measured_states_set_.count(state_str)) { |
665 | measured_states_set_.insert(std::move(state_str)); |
666 | measured_states_vector_.push_back(state); |
667 | inputs.push_back(MeasureInput(search_task, state)); |
668 | } |
669 | } |
670 | |
671 | return inputs; |
672 | } |
673 | |
674 | /********** PreloadCustomSketchRule **********/ |
675 | TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode); |
676 | |
677 | PreloadCustomSketchRule::PreloadCustomSketchRule(PackedFunc meet_condition_func, |
678 | PackedFunc apply_func, String rule_name) { |
679 | auto node = make_object<PreloadCustomSketchRuleNode>(); |
680 | node->meet_condition_func = std::move(meet_condition_func); |
681 | node->apply_func = std::move(apply_func); |
682 | node->rule_name = std::move(rule_name); |
683 | data_ = std::move(node); |
684 | } |
685 | |
686 | void PreloadCustomSketchRuleNode::Callback(SearchPolicyNode* policy) { |
687 | CHECK(policy->IsInstance<SketchPolicyNode>()); |
688 | auto sketch_policy = dynamic_cast<SketchPolicyNode*>(policy); |
689 | sketch_policy->sketch_rules.push_back( |
690 | new RuleCustomSketch(meet_condition_func, apply_func, rule_name)); |
691 | StdCout(policy->verbose) << "Custom sketch rule \"" << rule_name << "\" added." << std::endl; |
692 | } |
693 | |
694 | TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicy" ) |
695 | .set_body_typed([](SearchTask task, CostModel program_cost_model, Map<String, ObjectRef> params, |
696 | int seed, int verbose, |
697 | Optional<Array<SearchCallback>> init_search_callbacks) { |
698 | return SketchPolicy(task, program_cost_model, params, seed, verbose, init_search_callbacks); |
699 | }); |
700 | |
701 | TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicyGenerateSketches" ) |
702 | .set_body_typed([](SketchPolicy policy) { return policy->GenerateSketches(); }); |
703 | |
704 | TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicySampleInitialPopulation" ) |
705 | .set_body_typed([](SketchPolicy policy) { |
706 | const Array<State>& sketches = policy->GenerateSketches(); |
707 | |
708 | Array<State> init_population = policy->SampleInitPopulation(sketches); |
709 | return init_population; |
710 | }); |
711 | |
712 | TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicyEvolutionarySearch" ) |
713 | .set_body_typed([](SketchPolicy policy, Array<State> init_population, int out_size) { |
714 | Array<State> states = policy->EvolutionarySearch(init_population, out_size); |
715 | return states; |
716 | }); |
717 | |
718 | TVM_REGISTER_GLOBAL("auto_scheduler.PrintTitle" ).set_body_typed([](std::string title) { |
719 | PrintTitle(title, 1); |
720 | }); |
721 | |
722 | TVM_REGISTER_GLOBAL("auto_scheduler.PreloadCustomSketchRule" ) |
723 | .set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name) { |
724 | return PreloadCustomSketchRule(meet_condition_func, apply_func, rule_name); |
725 | }); |
726 | |
727 | } // namespace auto_scheduler |
728 | } // namespace tvm |
729 | |