1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/search_policy/sketch_policy.h
22 * \brief This search policy constructs a search space according to the compute declaration.
23 * It then randomly samples programs from the search space and uses evolutionary search with a
 * learned cost model to fine-tune the sampled programs.
 * The final optimized programs are sent to actual hardware for measurement.
 * The above process is repeated until the auto-scheduler runs out of its time budget.
27 *
28 * Reference:
 * L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor: Generating High-Performance Tensor
30 * Programs for Deep Learning." (OSDI 2020).
31 */
32
33#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_H_
34#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_H_
35
#include <tvm/auto_scheduler/cost_model.h>
#include <tvm/auto_scheduler/search_policy.h>

#include <memory>
#include <random>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "sketch_policy_rules.h"
#include "utils.h"
48
49namespace tvm {
50namespace auto_scheduler {
51
/*!
 * \brief String keys used in the parameter map of SketchPolicy.
 * Nested structs group the keys that belong to one stage of the search.
 */
struct SketchParamKey {
  /*! \brief Always allocate this percentage of measurements to randomly sampled states. */
  static constexpr const char* eps_greedy = "eps_greedy";
  /*! \brief Retry several times if SearchOneRound gets no valid state. */
  static constexpr const char* empty_retry_count = "retry_search_one_round_on_empty";

  /*! \brief Keys that configure the initial population sampling. */
  struct SampleInitPopulation {
    /*! \brief The minimal size of the valid population in the initial sampling. */
    static constexpr const char* min_population = "sample_init_min_population";
    /*! \brief The maximum percentage of measured states in the initial sampling. */
    static constexpr const char* use_measured_ratio = "sample_init_use_measured_ratio";
  };

  /*! \brief Keys that configure the evolutionary search. */
  struct EvolutionarySearch {
    /*! \brief The population size of the evolutionary search. */
    static constexpr const char* population = "evolutionary_search_population";
    /*! \brief The number of iterations performed by the genetic algorithm. */
    static constexpr const char* num_iters = "evolutionary_search_num_iters";
    /*! \brief The mutation probability. */
    static constexpr const char* mutation_prob = "evolutionary_search_mutation_prob";
  };

  /*! \brief Keys that configure multi-level tiling. */
  struct MultiLevelTiling {
    /*! \brief The structure of multi-level tiling for CPU. */
    static constexpr const char* cpu_structure = "cpu_multi_level_tiling_structure";
    /*! \brief The structure of multi-level tiling for GPU. */
    static constexpr const char* gpu_structure = "gpu_multi_level_tiling_structure";
  };

  /*! \brief The maximum innermost split factor. */
  static constexpr const char* max_innermost_split_factor = "max_innermost_split_factor";
  /*! \brief The maximum vectorize size. */
  static constexpr const char* max_vectorize_size = "max_vectorize_size";
  /*! \brief Whether to disable changing compute locations. */
  static constexpr const char* disable_change_compute_location = "disable_change_compute_location";
};
89
90class SketchPolicy;
91
92/*!
93 * \brief The search policy that searches in a hierarchical search space defined by sketches.
94 * The policy randomly samples programs from the space defined by sketches
95 * and use evolutionary search to fine-tune them.
96 */
97class SketchPolicyNode : public SearchPolicyNode {
98 public:
99 /*! \brief The cost model to estimate the complete schedules. */
100 CostModel program_cost_model;
101 /*! \brief The parameters map for this search policy. */
102 Map<String, ObjectRef> params;
103 /*! \brief The rules to generate sketches. */
104 std::vector<SketchGenerationRule*> sketch_rules;
105 /*! \brief The rules to generate initial population. */
106 std::vector<PopulationGenerationRule*> init_rules;
107 /*! \brief The rules to mutate states in the evolutionary search. */
108 std::vector<std::shared_ptr<PopulationMutationRule>> mutation_rules;
109 /*! \brief Random generator. */
110 std::mt19937 rand_gen;
111 /*! \brief Memorize split space for Split. */
112 SplitFactorizationMemo split_memo;
113
114 State Search(int num_measure_trials, int early_stopping, int num_measures_per_round,
115 ProgramMeasurer measurer) final;
116
117 std::pair<Array<MeasureInput>, Array<MeasureResult>> ContinueSearchOneRound(
118 int num_measure, ProgramMeasurer measurer) final;
119
120 /*!
121 * \brief Generate sketches.
122 * \return The generated sketches(states).
123 */
124 Array<State> GenerateSketches();
125
126 /*!
127 * \brief Sample the init population.
128 * \param sketches The initial sketches for the sampled population
129 * \return The generated states (the initial population).
130 */
131 Array<State> SampleInitPopulation(const Array<State>& sketches);
132
133 /*!
134 * \brief Perform evolutionary search.
135 * \param init_populations The states generated from init population.
136 * \param out_size The number of expected output states.
137 * \return The generated states after evolutionary search.
138 */
139 Array<State> EvolutionarySearch(const Array<State>& init_populations, int out_size);
140
141 static constexpr const char* _type_key = "auto_scheduler.SketchPolicy";
142
143 TVM_DECLARE_FINAL_OBJECT_INFO(SketchPolicyNode, SearchPolicyNode);
144
145 private:
146 /*!
147 * \brief Run one round of the search pipeline.
148 * \param num_random_states Number of states that are picked randomly, this is used for
149 * eps-greedy policy.
150 * \param random_states The picked random states, used as one of the output of this function.
151 * \return The best several states generated in this search round.
152 */
153 Array<State> SearchOneRound(int num_random_states, Array<State>* random_states = nullptr);
154
155 /*!
156 * \brief Pick states from best states and random states with eps-greedy policy.
157 * \param best_states States picked by cost model.
158 * \param random_states States picked randomly.
159 * \param remaining_n_trials The remaining number of states need to be generated.
160 * \return The generated states to be measured, wrapped in MeasureInput.
161 */
162 Array<MeasureInput> PickStatesWithEpsGreedy(const Array<State>& best_states,
163 const Array<State>& random_states,
164 int remaining_n_trials);
165
166 /*! \brief The number of states to measure per iteration. */
167 int num_measure_per_iter_;
168
169 /*! \brief The cached sketches */
170 Array<State> sketch_cache_;
171
172 /*! \brief The minimul output population of SampleInitPopulation */
173 int sample_init_min_pop_;
174
175 friend class SketchPolicy;
176};
177
/*!
 * \brief Managed reference to SketchPolicyNode.
 * \sa SketchPolicyNode
 */
class SketchPolicy : public SearchPolicy {
 public:
  /*!
   * \brief The constructor.
   * \param task The SearchTask for the computation declaration.
   * \param program_cost_model The cost model for complete programs.
   * \param params The parameters map for this search process (keys are listed in
   *  SketchParamKey).
   * \param seed The random seed of this search process.
   *  NOTE(review): presumably used to seed SketchPolicyNode::rand_gen; confirm in the
   *  constructor definition.
   * \param verbose Verbose level. 0 for silent, 1 to output information during schedule
   *  search.
   * \param init_search_callbacks Optional SearchCallbacks to be called before schedule search.
   */
  SketchPolicy(SearchTask task, CostModel program_cost_model, Map<String, ObjectRef> params,
               int seed, int verbose, Optional<Array<SearchCallback>> init_search_callbacks);

  // TVM object-reference boilerplate (the MUTABLE variant; see the macro's definition for the
  // exact members it generates).
  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SketchPolicy, SearchPolicy, SketchPolicyNode);
};
199
/*!
 * \brief Pre-search callback that loads a custom rule for sketch generation.
 * The rule is described by a condition-check function and an apply function, both held as
 * PackedFuncs, plus a name.
 */
class PreloadCustomSketchRuleNode : public SearchCallbackNode {
 public:
  /*! \brief The condition check function of this rule. */
  PackedFunc meet_condition_func;
  /*! \brief The apply function of this rule. */
  PackedFunc apply_func;
  /*! \brief The name of this rule. */
  String rule_name;

  /*!
   * \brief Invoked with the search policy before the search starts, to install this rule.
   * \param policy The search policy that receives the custom rule.
   *  NOTE(review): the implementation presumably expects a SketchPolicyNode; confirm.
   */
  void Callback(SearchPolicyNode* policy) final;

  // Type key under which this node type is registered in TVM's object system.
  static constexpr const char* _type_key = "auto_scheduler.PreloadCustomSketchRule";
  TVM_DECLARE_FINAL_OBJECT_INFO(PreloadCustomSketchRuleNode, SearchCallbackNode);
};
215
/*!
 * \brief Managed reference to PreloadCustomSketchRuleNode.
 * \sa PreloadCustomSketchRuleNode
 */
class PreloadCustomSketchRule : public SearchCallback {
 public:
  /*!
   * \brief The constructor.
   * \param meet_condition_func The condition check function of this rule.
   * \param apply_func The apply function of this rule.
   * \param rule_name The name of this rule.
   */
  PreloadCustomSketchRule(PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name);

  // TVM object-reference boilerplate (the MUTABLE variant; see the macro's definition for the
  // exact members it generates).
  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadCustomSketchRule, SearchCallback,
                                        PreloadCustomSketchRuleNode);
};
233
234} // namespace auto_scheduler
235} // namespace tvm
236
237#endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_H_
238