1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/auto_scheduler/search_policy.h |
22 | * \brief The base class of search policies, including the abstract definition of search policy and |
23 | * other supporting data structures. |
24 | * |
25 | * \note How to add a new search policy. |
 * By design, users should not need to implement their own search policy: the formal search
 * policy (to be added later) should cover most use cases. In addition, a custom rule mechanism
 * will be provided so that user-defined template search can serve the same purpose as the
 * current AutoTVM templates.
30 | * |
 * This guide is for advanced users who have special requirements.
 * 1. Search() and ContinueSearchOneRound() are pure virtual and must be implemented. Search()
 *    searches schedules for the task stored in the policy's `search_task` member and returns the
 *    best state found.
34 | * 2. Information about the compute declaration of ops/subgraphs can be acquired from SearchTask. |
35 | * This structure also contains some information about the target device. (e.g. knowing the width |
36 | * of the device vector unit, we can limit the max vectorize size during schedule search) |
 * 3. SearchCallback provides extra flexibility to perform additional work before/after the
 *    search process.
 * 4. ProgramMeasurer provides a simple but useful API to check the performance of the states
 *    obtained during the search process.
40 | */ |
41 | |
42 | #ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_ |
43 | #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_ |
44 | |
45 | #include <tvm/auto_scheduler/measure.h> |
46 | #include <tvm/auto_scheduler/search_task.h> |
47 | #include <tvm/node/node.h> |
48 | |
49 | #include <string> |
50 | #include <unordered_set> |
51 | #include <utility> |
52 | #include <vector> |
53 | |
54 | namespace tvm { |
55 | namespace auto_scheduler { |
56 | |
// Forward declarations. SearchPolicyNode is referenced by SearchCallbackNode::Callback before
// its definition further down; ProgramMeasurer appears in SearchPolicyNode's method signatures.
class SearchPolicyNode;
class ProgramMeasurer;
59 | |
60 | /*! |
61 | * \brief Callback function to be called by the search process. |
62 | * This interface allows to do extra initializations before schedule search or extra |
63 | * check during/after the schedule search. |
64 | */ |
65 | class SearchCallbackNode : public Object { |
66 | public: |
67 | /*! |
68 | * \brief Run the registered callback function. |
69 | * \param policy A pointer to a SearchPolicyNode. |
70 | */ |
71 | virtual void Callback(SearchPolicyNode* policy) = 0; |
72 | |
73 | static constexpr const char* _type_key = "auto_scheduler.SearchCallback" ; |
74 | TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object); |
75 | }; |
76 | |
77 | /*! |
78 | * \brief Managed reference to SearchCallbackNode. |
79 | * \sa SearchCallbackNode |
80 | */ |
81 | class SearchCallback : public ObjectRef { |
82 | public: |
83 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode); |
84 | }; |
85 | |
86 | /*! \brief Preload measured states from a log file. |
87 | * This can resume the state of the search policy */ |
88 | class PreloadMeasuredStatesNode : public SearchCallbackNode { |
89 | public: |
90 | /*! \brief The name of the record log file. */ |
91 | String filename; |
92 | |
93 | void Callback(SearchPolicyNode* policy) final; |
94 | |
95 | static constexpr const char* _type_key = "auto_scheduler.PreloadMeasuredStates" ; |
96 | TVM_DECLARE_FINAL_OBJECT_INFO(PreloadMeasuredStatesNode, SearchCallbackNode); |
97 | }; |
98 | |
99 | /*! |
100 | * \brief Managed reference to PreloadMeasuredStatesNode. |
101 | * \sa PreloadMeasuredStatesNode |
102 | */ |
103 | class PreloadMeasuredStates : public SearchCallback { |
104 | public: |
105 | /*! |
106 | * \brief The constructor. |
107 | * \param filename The name of the record log file. |
108 | */ |
109 | explicit PreloadMeasuredStates(String filename); |
110 | |
111 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadMeasuredStates, SearchCallback, |
112 | PreloadMeasuredStatesNode); |
113 | }; |
114 | |
/*!
 * \brief Op attribute keys that search policies recognize.
 */
struct SearchPolicyKey {
  /*! \brief Always unroll the innermost iterator among the specified iterators. */
  static constexpr const char* always_unroll_inner{"auto_scheduler_always_unroll_inner"};
  /*! \brief Place the specified iterators in the innermost tile without splitting them. */
  static constexpr const char* no_split_at_inner{"auto_scheduler_no_split_at_inner"};
  /*! \brief The specified iterators are indices of const tensors in a "fake reduction". */
  static constexpr const char* simplify_const_tensor_indices{
      "auto_scheduler_simplify_const_tensor_indices"};
};
125 | |
126 | /*! |
127 | * \brief The base class of search policies. |
128 | */ |
129 | class SearchPolicyNode : public Object { |
130 | public: |
131 | /*! \brief The current search task. */ |
132 | SearchTask search_task; |
133 | /*! |
134 | * \brief Verbose level to control the screen output during schedule search. |
135 | * 0 for silent, 1 to output state & measure information during search process. |
136 | */ |
137 | int verbose; |
138 | |
139 | void VisitAttrs(AttrVisitor* v) { |
140 | v->Visit("search_task" , &search_task); |
141 | v->Visit("verbose" , &verbose); |
142 | } |
143 | |
144 | /*! |
145 | * \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state |
146 | * found during the search. |
147 | * \param num_measure_trials The number of total measurement trials. |
148 | * \param early_stopping Stops the tuning early if no improvement after n measurements. |
149 | * \param num_measures_per_round The number of programs to be measured at each search round. |
150 | * \param measurer A ProgramMeasurer to build and measure programs |
151 | * \return The best state found. |
152 | */ |
153 | virtual State Search(int num_measure_trials, int early_stopping, int num_measures_per_round, |
154 | ProgramMeasurer measurer) = 0; |
155 | |
156 | /*! |
157 | * \brief Continue the search by doing an additional search round. |
158 | * \param num_measure The number of measurements |
159 | * \param measurer The measurer to measure programs |
160 | * \return The measurement records for measurements in this search round |
161 | */ |
162 | virtual std::pair<Array<MeasureInput>, Array<MeasureResult>> ContinueSearchOneRound( |
163 | int num_measure, ProgramMeasurer measurer) = 0; |
164 | |
165 | /*! |
166 | * \brief Preload measured states from a log file to resume the state of the search policy. |
167 | * \param log_file The name of the record log file. |
168 | */ |
169 | void PreloadMeasuredStates(const String& log_file); |
170 | |
171 | /*! |
172 | * \brief Call SearchCallback with the current SearchPolicyNode |
173 | * \param callbacks SearchCallback to be called. |
174 | */ |
175 | void RunCallbacks(const Array<SearchCallback>& callbacks); |
176 | |
177 | static constexpr const char* _type_key = "auto_scheduler.SearchPolicy" ; |
178 | TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); |
179 | |
180 | protected: |
181 | /*! |
182 | * \brief The set of already measured states. |
183 | * We store the string format of a state for redundancy check. This is used to make sure a |
184 | * measured state will never be measured again. |
185 | */ |
186 | std::unordered_set<std::string> measured_states_set_; |
187 | /*! \brief The array of already measured states. |
188 | * The good states can be used as the initial population in evolutionary search. */ |
189 | std::vector<State> measured_states_vector_; |
190 | /*! \brief The throughputs of already measured states */ |
191 | std::vector<float> measured_states_throughputs_; |
192 | }; |
193 | |
194 | /*! |
195 | * \brief Managed reference to SearchPolicyNode. |
196 | * \sa SearchPolicyNode |
197 | */ |
198 | class SearchPolicy : public ObjectRef { |
199 | public: |
200 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchPolicy, ObjectRef, SearchPolicyNode); |
201 | }; |
202 | |
203 | } // namespace auto_scheduler |
204 | } // namespace tvm |
205 | |
206 | #endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_ |
207 | |