1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_META_SCHEDULE_TASK_SCHEDULER_H_ |
20 | #define TVM_META_SCHEDULE_TASK_SCHEDULER_H_ |
21 | |
22 | #include <tvm/meta_schedule/builder.h> |
23 | #include <tvm/meta_schedule/cost_model.h> |
24 | #include <tvm/meta_schedule/measure_callback.h> |
25 | #include <tvm/meta_schedule/runner.h> |
26 | #include <tvm/meta_schedule/tune_context.h> |
27 | #include <tvm/node/reflection.h> |
28 | #include <tvm/runtime/container/array.h> |
29 | #include <tvm/runtime/container/optional.h> |
30 | #include <tvm/runtime/object.h> |
31 | #include <tvm/runtime/packed_func.h> |
32 | #include <tvm/support/random_engine.h> |
33 | |
34 | #include <string> |
35 | #include <vector> |
36 | |
37 | namespace tvm { |
38 | namespace meta_schedule { |
39 | |
40 | class TaskRecordNode : public runtime::Object { |
41 | public: |
42 | /*! \brief The tune context of the task. */ |
43 | TuneContext ctx{nullptr}; |
44 | /*! \brief The weight of the task */ |
45 | double task_weight{1.0}; |
46 | /*! \brief The FLOP count of the task */ |
47 | double flop{1.0}; |
48 | /*! \brief Whether the tuning task has been stopped or finished. */ |
49 | bool is_terminated = false; |
50 | /*! \brief Builder errors happens in the task */ |
51 | int build_error_count = 0; |
52 | /*! \brief Runner errors happens in the task */ |
53 | int run_error_count = 0; |
54 | /*! \brief The latency of each run, in milliseconds. */ |
55 | std::vector<double> latency_ms = {}; |
56 | /*! \brief The measure candidates. */ |
57 | Optional<Array<MeasureCandidate>> measure_candidates = NullOpt; |
58 | /*! \brief The building results. */ |
59 | Optional<Array<BuilderResult>> builder_results = NullOpt; |
60 | /*! \brief Packed functions to fetch the runner results asynchronously. */ |
61 | Optional<Array<RunnerFuture>> runner_futures = NullOpt; |
62 | |
63 | void VisitAttrs(tvm::AttrVisitor* v) { |
64 | v->Visit("ctx" , &ctx); |
65 | v->Visit("task_weight" , &task_weight); |
66 | v->Visit("flop" , &flop); |
67 | v->Visit("is_terminated" , &is_terminated); |
68 | v->Visit("build_error_count" , &build_error_count); |
69 | v->Visit("run_error_count" , &run_error_count); |
70 | // `latency_ms` is not visited |
71 | v->Visit("measure_candidates" , &measure_candidates); |
72 | v->Visit("builder_results" , &builder_results); |
73 | v->Visit("runner_futures" , &runner_futures); |
74 | } |
75 | |
76 | static constexpr const char* _type_key = "meta_schedule.TaskRecord" ; |
77 | TVM_DECLARE_FINAL_OBJECT_INFO(TaskRecordNode, Object); |
78 | }; |
79 | |
80 | /*! |
81 | * \brief Managed reference to TaskRecordNode. |
82 | * \sa TaskRecordNode |
83 | */ |
84 | class TaskRecord : public runtime::ObjectRef { |
85 | public: |
86 | /*! \brief Constructor */ |
87 | explicit TaskRecord(TuneContext task, double task_weight); |
88 | |
89 | TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskRecord, ObjectRef, TaskRecordNode); |
90 | }; |
91 | |
92 | /*! |
93 | * \brief The abstract interface of task schedulers. |
94 | * \note The relationship between SpaceGenerator and other classes are as follows: |
95 | ┌──────────────────────────────────────────────────────────────┐ |
96 | ┌──┴───────────────────────────────────────────────────────────┐ │ |
97 | ┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ |
98 | │ ┌─────────────────────┐ │ │ │ |
99 | │ │ │ Generate │ │ │ |
100 | │ │ Space Generator ├──────────────┐ │ │ │ |
101 | │ │ │ │ │ │ │ |
102 | │ └─────────────────────┘ ▼ │ │ │ |
103 | │ Design Space │ │ │ |
104 | │ ┌─────────────────────┐ │ │ │ │ |
105 | │ Generate │ │ Pretuning │ │ │ │ |
106 | │ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │ |
107 | │ │ │ │ │ ├──┘ |
108 | │ │ └─────────────────────┘ ├──┘ |
109 | └────┼─────────────────────────────────────────────────────────┘ |
110 | │ |
111 | │ |
112 | ┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐ |
113 | │ │ ┌───────────┐ │ |
114 | │ │ Send to │ │ Send to │ |
115 | │ ▼ ┌─────────────►│ Builder ├──────────┐ │ |
116 | │ Measure Candidate │ Builder │ │ Runner │ │ |
117 | │ │ │ └───────────┘ │ │ |
118 | │ │ ┌────────────┴────────┐ │ │ |
119 | │ │ │ │ ┌───────────┐ │ │ |
120 | │ └────►│ Task Scheduler │ │ │ │ │ |
121 | │ │ │ │ Runner │◄─────────┘ │ |
122 | │ └─────────────────────┘ │ │ │ |
123 | │ ▲ └─────┬─────┘ │ |
124 | │ │ │ │ |
125 | │ └─── Runner Future ◄────┘ │ |
126 | └─────────────────────────────────────────────────────────────────────┘ |
127 | */ |
128 | class TaskSchedulerNode : public runtime::Object { |
129 | public: |
130 | /*! \brief The tuning task's logging function. */ |
131 | PackedFunc logger; |
132 | /*! \brief Records for each task */ |
133 | Array<TaskRecord> tasks_; |
134 | /*! \brief The list of measure callbacks of the scheduler. */ |
135 | Array<MeasureCallback> measure_callbacks_; |
136 | /*! \brief The database used in tuning */ |
137 | Optional<Database> database_; |
138 | /*! \brief The cost model used in tuning */ |
139 | Optional<CostModel> cost_model_; |
140 | /*! \brief The number of remaining tasks to be tuned. */ |
141 | int remaining_tasks_; |
142 | |
143 | /*! \brief The default destructor. */ |
144 | virtual ~TaskSchedulerNode() = default; |
145 | |
146 | void VisitAttrs(tvm::AttrVisitor* v) { |
147 | // `logger` is not visited |
148 | v->Visit("tasks_" , &tasks_); |
149 | v->Visit("measure_callbacks_" , &measure_callbacks_); |
150 | v->Visit("database_" , &database_); |
151 | v->Visit("cost_model_" , &cost_model_); |
152 | v->Visit("remaining_tasks_" , &remaining_tasks_); |
153 | } |
154 | |
155 | /*! |
156 | * \brief Fetch the next task id. |
157 | * \return The next task id. |
158 | */ |
159 | virtual int NextTaskId() = 0; |
160 | /*! |
161 | * \brief Wait until the task is finished. |
162 | * \param task_id The task id to be joined. |
163 | * \return The results from the runner. |
164 | */ |
165 | virtual Array<RunnerResult> JoinRunningTask(int task_id); |
166 | /*! |
167 | * \brief Jointly tune a given list of tasks. |
168 | * \param tasks The tasks to be tuned |
169 | * \param task_weights The weight of each task |
170 | * \param max_trials_global The maximum number of trials to be performed globally |
171 | * \param max_trials_per_task The maximum number of trials to be performed for each task |
172 | * \param num_trials_per_iter The number of trials to be performed in each iteration |
173 | * \param builder The MetaSchedule builder |
174 | * \param runner The MetaSchedule runner |
175 | * \param measure_callbacks The callbacks to be called after each measurement |
176 | * \param database The database used in tuning |
177 | * \param cost_model The cost model used in tuning |
178 | */ |
179 | virtual void Tune(Array<TuneContext> tasks, // |
180 | Array<FloatImm> task_weights, // |
181 | int max_trials_global, // |
182 | int max_trials_per_task, // |
183 | int num_trials_per_iter, // |
184 | Builder builder, // |
185 | Runner runner, // |
186 | Array<MeasureCallback> measure_callbacks, // |
187 | Optional<Database> database, // |
188 | Optional<CostModel> cost_model); |
189 | /*! |
190 | * \brief Terminate a task |
191 | * \param task_id The id of the task to be terminated |
192 | */ |
193 | void TerminateTask(int task_id); |
194 | /*! |
195 | * \brief Touch the task and update its status |
196 | * \param task_id The task id to be checked. |
197 | */ |
198 | void TouchTask(int task_id); |
199 | /*! \brief Print out a human-readable format of the tuning statistics. */ |
200 | void PrintTuningStatistics(); |
201 | |
202 | static constexpr const char* _type_key = "meta_schedule.TaskScheduler" ; |
203 | TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object); |
204 | }; |
205 | |
206 | class TaskScheduler; |
207 | |
208 | /*! \brief The task scheduler with customized methods on the python-side. */ |
209 | class PyTaskSchedulerNode : public TaskSchedulerNode { |
210 | public: |
211 | /*! |
212 | * \brief The function type of `NextTaskId` method. |
213 | * \return The next task id. |
214 | */ |
215 | using FNextTaskId = runtime::TypedPackedFunc<int()>; |
216 | /*! |
217 | * \brief The function type of `JoinRunningTask` method. |
218 | * \param task_id The task id to be joined. |
219 | */ |
220 | using FJoinRunningTask = runtime::TypedPackedFunc<Array<RunnerResult>(int)>; |
221 | /*! \brief The function type of `Tune` method. */ |
222 | using FTune = runtime::TypedPackedFunc<void(Array<TuneContext> tasks, // |
223 | Array<FloatImm> task_weights, // |
224 | int max_trials_global, // |
225 | int max_trials_per_task, // |
226 | int num_trials_per_iter, // |
227 | Builder builder, // |
228 | Runner runner, // |
229 | Array<MeasureCallback> measure_callbacks, // |
230 | Optional<Database> database, // |
231 | Optional<CostModel> cost_model)>; |
232 | |
233 | /*! \brief The packed function to the `NextTaskId` function. */ |
234 | FNextTaskId f_next_task_id; |
235 | /*! \brief The packed function to the `JoinRunningTask` function. */ |
236 | FJoinRunningTask f_join_running_task; |
237 | /*! \brief The packed function to the `Tune` function. */ |
238 | FTune f_tune; |
239 | |
240 | void VisitAttrs(tvm::AttrVisitor* v) { |
241 | TaskSchedulerNode::VisitAttrs(v); |
242 | // `f_next_task_id` is not visited |
243 | // `f_join_running_task` is not visited |
244 | // `f_tune` is not visited |
245 | } |
246 | |
247 | int NextTaskId() final; |
248 | Array<RunnerResult> JoinRunningTask(int task_id) final; |
249 | void Tune(Array<TuneContext> tasks, Array<FloatImm> task_weights, int max_trials_global, |
250 | int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner, |
251 | Array<MeasureCallback> measure_callbacks, Optional<Database> database, |
252 | Optional<CostModel> cost_model) final; |
253 | |
254 | static constexpr const char* _type_key = "meta_schedule.PyTaskScheduler" ; |
255 | TVM_DECLARE_FINAL_OBJECT_INFO(PyTaskSchedulerNode, TaskSchedulerNode); |
256 | }; |
257 | |
258 | /*! |
259 | * \brief Managed reference to TaskSchedulerNode. |
260 | * \sa TaskSchedulerNode |
261 | */ |
262 | class TaskScheduler : public runtime::ObjectRef { |
263 | public: |
264 | /*! |
265 | * \brief Create a task scheduler that fetches tasks in a round-robin fashion. |
266 | * \param logger The tuning task's logging function. |
267 | * \return The task scheduler created. |
268 | */ |
269 | TVM_DLL static TaskScheduler RoundRobin(PackedFunc logger); |
270 | /*! |
271 | * \brief Create a task scheduler that fetches tasks in a gradient based fashion. |
272 | * \param logger The tuning task's logging function. |
273 | * \param alpha The parameter alpha to control gradient computation. |
274 | * \param window_size The parameter to control backward window size. |
275 | * \param seed The random seed. |
276 | * \return The task scheduler created. |
277 | */ |
278 | TVM_DLL static TaskScheduler GradientBased(PackedFunc logger, double alpha, int window_size, |
279 | support::LinearCongruentialEngine::TRandState seed); |
280 | /*! |
281 | * \brief Create a task scheduler with customized methods on the python-side. |
282 | * \param logger The tuning task's logging function. |
283 | * \param f_next_task_id The packed function of `NextTaskId`. |
284 | * \param f_join_running_task The packed function of `JoinRunningTask`. |
285 | * \param f_tune The packed function of `Tune`. |
286 | * \return The task scheduler created. |
287 | */ |
288 | TVM_DLL static TaskScheduler PyTaskScheduler( |
289 | PackedFunc logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id, |
290 | PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, PyTaskSchedulerNode::FTune f_tune); |
291 | TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode); |
292 | }; |
293 | |
294 | } // namespace meta_schedule |
295 | } // namespace tvm |
296 | |
297 | #endif // TVM_META_SCHEDULE_TASK_SCHEDULER_H_ |
298 | |