1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/auto_scheduler/search_policy.h
22 * \brief The base class of search policies, including the abstract definition of search policy and
23 * other supporting data structures.
24 *
25 * \note How to add a new search policy.
 * By design, users should not need to implement their own search policy; our formal search
 * policy (to be introduced later) should be enough to cover most use cases. Meanwhile, a custom
 * rule mechanism will be provided to enable user-defined template search, serving the same
 * purpose as the current AutoTVM template.
30 *
 * This guide is for advanced users who have special requirements.
32 * 1. The only function that must be implemented is Search(), which takes a task as input and
33 * returns the best states found.
34 * 2. Information about the compute declaration of ops/subgraphs can be acquired from SearchTask.
35 * This structure also contains some information about the target device. (e.g. knowing the width
36 * of the device vector unit, we can limit the max vectorize size during schedule search)
 * 3. SearchCallback provides more flexibility to do extra work before/after the search process.
 * 4. ProgramMeasurer provides a simple but useful API to help check the performance of states
 * obtained during the search process.
40 */
41
42#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
43#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
44
45#include <tvm/auto_scheduler/measure.h>
46#include <tvm/auto_scheduler/search_task.h>
47#include <tvm/node/node.h>
48
49#include <string>
50#include <unordered_set>
51#include <utility>
52#include <vector>
53
54namespace tvm {
55namespace auto_scheduler {
56
57class ProgramMeasurer;
58class SearchPolicyNode;
59
60/*!
61 * \brief Callback function to be called by the search process.
62 * This interface allows to do extra initializations before schedule search or extra
63 * check during/after the schedule search.
64 */
65class SearchCallbackNode : public Object {
66 public:
67 /*!
68 * \brief Run the registered callback function.
69 * \param policy A pointer to a SearchPolicyNode.
70 */
71 virtual void Callback(SearchPolicyNode* policy) = 0;
72
73 static constexpr const char* _type_key = "auto_scheduler.SearchCallback";
74 TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object);
75};
76
77/*!
78 * \brief Managed reference to SearchCallbackNode.
79 * \sa SearchCallbackNode
80 */
81class SearchCallback : public ObjectRef {
82 public:
83 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode);
84};
85
86/*! \brief Preload measured states from a log file.
87 * This can resume the state of the search policy */
88class PreloadMeasuredStatesNode : public SearchCallbackNode {
89 public:
90 /*! \brief The name of the record log file. */
91 String filename;
92
93 void Callback(SearchPolicyNode* policy) final;
94
95 static constexpr const char* _type_key = "auto_scheduler.PreloadMeasuredStates";
96 TVM_DECLARE_FINAL_OBJECT_INFO(PreloadMeasuredStatesNode, SearchCallbackNode);
97};
98
99/*!
100 * \brief Managed reference to PreloadMeasuredStatesNode.
101 * \sa PreloadMeasuredStatesNode
102 */
103class PreloadMeasuredStates : public SearchCallback {
104 public:
105 /*!
106 * \brief The constructor.
107 * \param filename The name of the record log file.
108 */
109 explicit PreloadMeasuredStates(String filename);
110
111 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadMeasuredStates, SearchCallback,
112 PreloadMeasuredStatesNode);
113};
114
/*!
 * \brief Attribute keys of ops used for SearchPolicy.
 * Each member is the string name of one attribute key.
 */
struct SearchPolicyKey {
  /*! \brief Always apply unroll to the inner most iterator of the specified iterators. */
  static constexpr const char* always_unroll_inner = "auto_scheduler_always_unroll_inner";
  /*! \brief The specified iterators will be placed in the inner most tile without split. */
  static constexpr const char* no_split_at_inner = "auto_scheduler_no_split_at_inner";
  /*! \brief The specified iterators are indices of const tensors in "fake reduction". */
  static constexpr const char* simplify_const_tensor_indices =
      "auto_scheduler_simplify_const_tensor_indices";
};
125
126/*!
127 * \brief The base class of search policies.
128 */
129class SearchPolicyNode : public Object {
130 public:
131 /*! \brief The current search task. */
132 SearchTask search_task;
133 /*!
134 * \brief Verbose level to control the screen output during schedule search.
135 * 0 for silent, 1 to output state & measure information during search process.
136 */
137 int verbose;
138
139 void VisitAttrs(AttrVisitor* v) {
140 v->Visit("search_task", &search_task);
141 v->Visit("verbose", &verbose);
142 }
143
144 /*!
145 * \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state
146 * found during the search.
147 * \param num_measure_trials The number of total measurement trials.
148 * \param early_stopping Stops the tuning early if no improvement after n measurements.
149 * \param num_measures_per_round The number of programs to be measured at each search round.
150 * \param measurer A ProgramMeasurer to build and measure programs
151 * \return The best state found.
152 */
153 virtual State Search(int num_measure_trials, int early_stopping, int num_measures_per_round,
154 ProgramMeasurer measurer) = 0;
155
156 /*!
157 * \brief Continue the search by doing an additional search round.
158 * \param num_measure The number of measurements
159 * \param measurer The measurer to measure programs
160 * \return The measurement records for measurements in this search round
161 */
162 virtual std::pair<Array<MeasureInput>, Array<MeasureResult>> ContinueSearchOneRound(
163 int num_measure, ProgramMeasurer measurer) = 0;
164
165 /*!
166 * \brief Preload measured states from a log file to resume the state of the search policy.
167 * \param log_file The name of the record log file.
168 */
169 void PreloadMeasuredStates(const String& log_file);
170
171 /*!
172 * \brief Call SearchCallback with the current SearchPolicyNode
173 * \param callbacks SearchCallback to be called.
174 */
175 void RunCallbacks(const Array<SearchCallback>& callbacks);
176
177 static constexpr const char* _type_key = "auto_scheduler.SearchPolicy";
178 TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object);
179
180 protected:
181 /*!
182 * \brief The set of already measured states.
183 * We store the string format of a state for redundancy check. This is used to make sure a
184 * measured state will never be measured again.
185 */
186 std::unordered_set<std::string> measured_states_set_;
187 /*! \brief The array of already measured states.
188 * The good states can be used as the initial population in evolutionary search. */
189 std::vector<State> measured_states_vector_;
190 /*! \brief The throughputs of already measured states */
191 std::vector<float> measured_states_throughputs_;
192};
193
194/*!
195 * \brief Managed reference to SearchPolicyNode.
196 * \sa SearchPolicyNode
197 */
198class SearchPolicy : public ObjectRef {
199 public:
200 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchPolicy, ObjectRef, SearchPolicyNode);
201};
202
203} // namespace auto_scheduler
204} // namespace tvm
205
206#endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_
207