1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/search_policy/sketch_policy.h |
 * \brief This search policy constructs a search space according to the compute declaration.
 * It then randomly samples programs from the search space and uses evolutionary search with a
 * learned cost model to fine-tune the sampled programs.
 * The final optimized programs are sent to the actual hardware for measurement.
 * The above process is repeated until the auto-scheduler runs out of its time budget.
27 | * |
28 | * Reference: |
 * L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor: Generating High-Performance Tensor
 * Programs for Deep Learning." (OSDI 2020).
31 | */ |
32 | |
33 | #ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_H_ |
34 | #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_H_ |
35 | |
36 | #include <tvm/auto_scheduler/cost_model.h> |
37 | #include <tvm/auto_scheduler/search_policy.h> |
38 | |
#include <memory>
#include <random>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
45 | |
46 | #include "sketch_policy_rules.h" |
47 | #include "utils.h" |
48 | |
49 | namespace tvm { |
50 | namespace auto_scheduler { |
51 | |
/*!
 * \brief String keys used in the parameter map of SketchPolicy.
 *
 * Each member holds the exact key string looked up in SketchPolicyNode::params.
 * Nested structs group the keys that belong to one stage of the search.
 */
struct SketchParamKey {
  /*! \brief Always allocate this percentage of measurements to randomly sampled states. */
  static constexpr const char* eps_greedy = "eps_greedy";
  /*! \brief Retry several times if SearchOneRound gets no valid state. */
  static constexpr const char* empty_retry_count = "retry_search_one_round_on_empty";

  /*! \brief Keys that control the initial population sampling. */
  struct SampleInitPopulation {
    /*! \brief The minimal size of the valid population in the initial sampling. */
    static constexpr const char* min_population = "sample_init_min_population";
    /*! \brief The maximum percentage of measured states used in the initial sampling. */
    static constexpr const char* use_measured_ratio = "sample_init_use_measured_ratio";
  };

  /*! \brief Keys that control the evolutionary search. */
  struct EvolutionarySearch {
    /*! \brief The population size of the evolutionary search. */
    static constexpr const char* population = "evolutionary_search_population";
    /*! \brief The number of iterations performed by the genetic algorithm. */
    static constexpr const char* num_iters = "evolutionary_search_num_iters";
    /*! \brief The mutation probability. */
    static constexpr const char* mutation_prob = "evolutionary_search_mutation_prob";
  };

  /*! \brief Keys that describe the multi-level tiling structure per target kind. */
  struct MultiLevelTiling {
    /*! \brief The structure of multi-level tiling for CPU. */
    static constexpr const char* cpu_structure = "cpu_multi_level_tiling_structure";
    /*! \brief The structure of multi-level tiling for GPU. */
    static constexpr const char* gpu_structure = "gpu_multi_level_tiling_structure";
  };

  /*! \brief The maximum innermost split factor. */
  static constexpr const char* max_innermost_split_factor = "max_innermost_split_factor";
  /*! \brief The maximum vectorize size. */
  static constexpr const char* max_vectorize_size = "max_vectorize_size";
  /*! \brief Whether to disable changing the compute location. */
  static constexpr const char* disable_change_compute_location = "disable_change_compute_location";
};
89 | |
90 | class SketchPolicy; |
91 | |
92 | /*! |
93 | * \brief The search policy that searches in a hierarchical search space defined by sketches. |
94 | * The policy randomly samples programs from the space defined by sketches |
95 | * and use evolutionary search to fine-tune them. |
96 | */ |
97 | class SketchPolicyNode : public SearchPolicyNode { |
98 | public: |
99 | /*! \brief The cost model to estimate the complete schedules. */ |
100 | CostModel program_cost_model; |
101 | /*! \brief The parameters map for this search policy. */ |
102 | Map<String, ObjectRef> params; |
103 | /*! \brief The rules to generate sketches. */ |
104 | std::vector<SketchGenerationRule*> sketch_rules; |
105 | /*! \brief The rules to generate initial population. */ |
106 | std::vector<PopulationGenerationRule*> init_rules; |
107 | /*! \brief The rules to mutate states in the evolutionary search. */ |
108 | std::vector<std::shared_ptr<PopulationMutationRule>> mutation_rules; |
109 | /*! \brief Random generator. */ |
110 | std::mt19937 rand_gen; |
111 | /*! \brief Memorize split space for Split. */ |
112 | SplitFactorizationMemo split_memo; |
113 | |
114 | State Search(int num_measure_trials, int early_stopping, int num_measures_per_round, |
115 | ProgramMeasurer measurer) final; |
116 | |
117 | std::pair<Array<MeasureInput>, Array<MeasureResult>> ContinueSearchOneRound( |
118 | int num_measure, ProgramMeasurer measurer) final; |
119 | |
120 | /*! |
121 | * \brief Generate sketches. |
122 | * \return The generated sketches(states). |
123 | */ |
124 | Array<State> GenerateSketches(); |
125 | |
126 | /*! |
127 | * \brief Sample the init population. |
128 | * \param sketches The initial sketches for the sampled population |
129 | * \return The generated states (the initial population). |
130 | */ |
131 | Array<State> SampleInitPopulation(const Array<State>& sketches); |
132 | |
133 | /*! |
134 | * \brief Perform evolutionary search. |
135 | * \param init_populations The states generated from init population. |
136 | * \param out_size The number of expected output states. |
137 | * \return The generated states after evolutionary search. |
138 | */ |
139 | Array<State> EvolutionarySearch(const Array<State>& init_populations, int out_size); |
140 | |
141 | static constexpr const char* _type_key = "auto_scheduler.SketchPolicy" ; |
142 | |
143 | TVM_DECLARE_FINAL_OBJECT_INFO(SketchPolicyNode, SearchPolicyNode); |
144 | |
145 | private: |
146 | /*! |
147 | * \brief Run one round of the search pipeline. |
148 | * \param num_random_states Number of states that are picked randomly, this is used for |
149 | * eps-greedy policy. |
150 | * \param random_states The picked random states, used as one of the output of this function. |
151 | * \return The best several states generated in this search round. |
152 | */ |
153 | Array<State> SearchOneRound(int num_random_states, Array<State>* random_states = nullptr); |
154 | |
155 | /*! |
156 | * \brief Pick states from best states and random states with eps-greedy policy. |
157 | * \param best_states States picked by cost model. |
158 | * \param random_states States picked randomly. |
159 | * \param remaining_n_trials The remaining number of states need to be generated. |
160 | * \return The generated states to be measured, wrapped in MeasureInput. |
161 | */ |
162 | Array<MeasureInput> PickStatesWithEpsGreedy(const Array<State>& best_states, |
163 | const Array<State>& random_states, |
164 | int remaining_n_trials); |
165 | |
166 | /*! \brief The number of states to measure per iteration. */ |
167 | int num_measure_per_iter_; |
168 | |
169 | /*! \brief The cached sketches */ |
170 | Array<State> sketch_cache_; |
171 | |
172 | /*! \brief The minimul output population of SampleInitPopulation */ |
173 | int sample_init_min_pop_; |
174 | |
175 | friend class SketchPolicy; |
176 | }; |
177 | |
178 | /*! |
179 | * \brief Managed reference to SketchPolicyNode. |
180 | * \sa SketchPolicyNode |
181 | */ |
182 | class SketchPolicy : public SearchPolicy { |
183 | public: |
184 | /*! |
185 | * \brief The constructor. |
186 | * \param task The SearchTask for the computation declaration. |
187 | * \param program_cost_model The cost model for complete programs. |
188 | * \param params The parameters map for this search process. |
189 | * \param seed The random seed of this search process. |
190 | * \param verbose Verbose level. 0 for silent, 1 to output information during schedule |
191 | * search. |
192 | * \param init_search_callbacks SearchCallback to be called before schedule search. |
193 | */ |
194 | SketchPolicy(SearchTask task, CostModel program_cost_model, Map<String, ObjectRef> params, |
195 | int seed, int verbose, Optional<Array<SearchCallback>> init_search_callbacks); |
196 | |
197 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SketchPolicy, SearchPolicy, SketchPolicyNode); |
198 | }; |
199 | |
200 | /*! \brief Pre-search callback function to load custom rules for sketch generation */ |
201 | class PreloadCustomSketchRuleNode : public SearchCallbackNode { |
202 | public: |
203 | /*! \brief The condition check function of this rule. */ |
204 | PackedFunc meet_condition_func; |
205 | /*! \brief The apply function of this rule. */ |
206 | PackedFunc apply_func; |
207 | /*! \brief The name of this rule. */ |
208 | String rule_name; |
209 | |
210 | void Callback(SearchPolicyNode* policy) final; |
211 | |
212 | static constexpr const char* _type_key = "auto_scheduler.PreloadCustomSketchRule" ; |
213 | TVM_DECLARE_FINAL_OBJECT_INFO(PreloadCustomSketchRuleNode, SearchCallbackNode); |
214 | }; |
215 | |
216 | /*! |
217 | * \brief Managed reference to PreloadCustomSketchRuleNode. |
218 | * \sa PreloadCustomSketchRuleNode |
219 | */ |
220 | class PreloadCustomSketchRule : public SearchCallback { |
221 | public: |
222 | /*! |
223 | * \brief The constructor. |
224 | * \param meet_condition_func The condition check function of this rule. |
225 | * \param apply_func The apply function of this rule. |
226 | * \param rule_name The name of this rule. |
227 | */ |
228 | PreloadCustomSketchRule(PackedFunc meet_condition_func, PackedFunc apply_func, String rule_name); |
229 | |
230 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadCustomSketchRule, SearchCallback, |
231 | PreloadCustomSketchRuleNode); |
232 | }; |
233 | |
234 | } // namespace auto_scheduler |
235 | } // namespace tvm |
236 | |
237 | #endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_H_ |
238 | |