/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#ifndef TVM_META_SCHEDULE_TASK_SCHEDULER_H_
#define TVM_META_SCHEDULE_TASK_SCHEDULER_H_

#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/cost_model.h>
#include <tvm/meta_schedule/measure_callback.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/tune_context.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/optional.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/support/random_engine.h>

#include <string>
#include <vector>

namespace tvm {
namespace meta_schedule {

class TaskRecordNode : public runtime::Object {
 public:
  /*! \brief The tune context of the task. */
  TuneContext ctx{nullptr};
  /*! \brief The weight of the task. */
  double task_weight{1.0};
  /*! \brief The FLOP count of the task. */
  double flop{1.0};
  /*! \brief Whether the tuning task has been stopped or finished. */
  bool is_terminated = false;
  /*! \brief The number of builder errors that happened in the task. */
  int build_error_count = 0;
  /*! \brief The number of runner errors that happened in the task. */
  int run_error_count = 0;
  /*! \brief The latency of each run, in milliseconds. */
  std::vector<double> latency_ms = {};
  /*! \brief The measure candidates. */
  Optional<Array<MeasureCandidate>> measure_candidates = NullOpt;
  /*! \brief The building results. */
  Optional<Array<BuilderResult>> builder_results = NullOpt;
  /*! \brief Packed functions to fetch the runner results asynchronously. */
  Optional<Array<RunnerFuture>> runner_futures = NullOpt;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("ctx", &ctx);
    v->Visit("task_weight", &task_weight);
    v->Visit("flop", &flop);
    v->Visit("is_terminated", &is_terminated);
    v->Visit("build_error_count", &build_error_count);
    v->Visit("run_error_count", &run_error_count);
    // `latency_ms` is not visited
    v->Visit("measure_candidates", &measure_candidates);
    v->Visit("builder_results", &builder_results);
    v->Visit("runner_futures", &runner_futures);
  }

  static constexpr const char* _type_key = "meta_schedule.TaskRecord";
  TVM_DECLARE_FINAL_OBJECT_INFO(TaskRecordNode, Object);
};

/*!
 * \brief Managed reference to TaskRecordNode.
 * \sa TaskRecordNode
 */
class TaskRecord : public runtime::ObjectRef {
 public:
  /*!
   * \brief Constructor.
   * \param task The tune context of the task.
   * \param task_weight The weight of the task.
   */
  explicit TaskRecord(TuneContext task, double task_weight);

  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskRecord, ObjectRef, TaskRecordNode);
};
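
// Illustrative sketch (not part of this header's API): how the fields of a TaskRecord are
// filled in over one measurement round. The helpers `GenerateCandidates`, `BuildAll`, and
// `RunAll` below are hypothetical placeholders, not functions declared in TVM; only the
// member accesses mirror `TaskRecordNode` above.
//
//   TaskRecord record(task_ctx, /*task_weight=*/1.0);
//   record->measure_candidates = GenerateCandidates(record->ctx);    // candidates to measure
//   record->builder_results = BuildAll(record->measure_candidates);  // artifacts or build errors
//   record->runner_futures = RunAll(record->builder_results);        // async handles to results
//   // When the futures resolve, measured latencies are appended to `latency_ms`, while failed
//   // builds/runs increment `build_error_count` / `run_error_count`.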

/*!
 * \brief The abstract interface of task schedulers.
 * \note The relationship between the task scheduler and other classes is as follows:
      ┌──────────────────────────────────────────────────────────────┐
   ┌──┴───────────────────────────────────────────────────────────┐  │
┌──┴────────────────── Tune Context ───────────────────────────┐  │  │
│                ┌─────────────────────┐                       │  │  │
│                │                     │   Generate            │  │  │
│                │   Space Generator   ├──────────────┐        │  │  │
│                │                     │              │        │  │  │
│                └─────────────────────┘              ▼        │  │  │
│                                              Design Space    │  │  │
│                ┌─────────────────────┐              │        │  │  │
│      Generate  │                     │   Pretuning  │        │  │  │
│    ┌───────────┤   Search Strategy   │◄─────────────┘        │  │  │
│    │           │                     │                       │  ├──┘
│    │           └─────────────────────┘                       ├──┘
└────┼─────────────────────────────────────────────────────────┘
     │
     │
┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐
│    │                                 ┌───────────┐                  │
│    │                      Send to    │           │  Send to         │
│    ▼                  ┌─────────────►│  Builder  ├──────────┐       │
│  Measure Candidate    │   Builder    │           │  Runner  │       │
│    │                  │              └───────────┘          │       │
│    │     ┌────────────┴────────┐                            │       │
│    │     │                     │     ┌───────────┐          │       │
│    └────►│   Task Scheduler    │     │           │          │       │
│          │                     │     │  Runner   │◄─────────┘       │
│          └─────────────────────┘     │           │                  │
│                   ▲                  └─────┬─────┘                  │
│                   │                        │                        │
│                   └─────  Runner Future ◄──┘                        │
└─────────────────────────────────────────────────────────────────────┘
*/
class TaskSchedulerNode : public runtime::Object {
 public:
  /*! \brief The tuning task's logging function. */
  PackedFunc logger;
  /*! \brief Records for each task. */
  Array<TaskRecord> tasks_;
  /*! \brief The list of measure callbacks of the scheduler. */
  Array<MeasureCallback> measure_callbacks_;
  /*! \brief The database used in tuning. */
  Optional<Database> database_;
  /*! \brief The cost model used in tuning. */
  Optional<CostModel> cost_model_;
  /*! \brief The number of remaining tasks to be tuned. */
  int remaining_tasks_;

  /*! \brief The default destructor. */
  virtual ~TaskSchedulerNode() = default;

  void VisitAttrs(tvm::AttrVisitor* v) {
    // `logger` is not visited
    v->Visit("tasks_", &tasks_);
    v->Visit("measure_callbacks_", &measure_callbacks_);
    v->Visit("database_", &database_);
    v->Visit("cost_model_", &cost_model_);
    v->Visit("remaining_tasks_", &remaining_tasks_);
  }

  /*!
   * \brief Fetch the next task id.
   * \return The next task id.
   */
  virtual int NextTaskId() = 0;
  /*!
   * \brief Wait until the task is finished.
   * \param task_id The task id to be joined.
   * \return The results from the runner.
   */
  virtual Array<RunnerResult> JoinRunningTask(int task_id);
  /*!
   * \brief Jointly tune a given list of tasks.
   * \param tasks The tasks to be tuned.
   * \param task_weights The weight of each task.
   * \param max_trials_global The maximum number of trials to be performed globally.
   * \param max_trials_per_task The maximum number of trials to be performed for each task.
   * \param num_trials_per_iter The number of trials to be performed in each iteration.
   * \param builder The MetaSchedule builder.
   * \param runner The MetaSchedule runner.
   * \param measure_callbacks The callbacks to be called after each measurement.
   * \param database The database used in tuning.
   * \param cost_model The cost model used in tuning.
   */
  virtual void Tune(Array<TuneContext> tasks,                  //
                    Array<FloatImm> task_weights,              //
                    int max_trials_global,                     //
                    int max_trials_per_task,                   //
                    int num_trials_per_iter,                   //
                    Builder builder,                           //
                    Runner runner,                             //
                    Array<MeasureCallback> measure_callbacks,  //
                    Optional<Database> database,               //
                    Optional<CostModel> cost_model);
  /*!
   * \brief Terminate a task.
   * \param task_id The id of the task to be terminated.
   */
  void TerminateTask(int task_id);
  /*!
   * \brief Touch the task and update its status.
   * \param task_id The task id to be checked.
   */
  void TouchTask(int task_id);
  /*! \brief Print out a human-readable format of the tuning statistics. */
  void PrintTuningStatistics();

  static constexpr const char* _type_key = "meta_schedule.TaskScheduler";
  TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object);
};
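
// Rough sketch (for orientation only) of the control flow that `Tune` drives; the concrete
// schedulers shipped with TVM differ in detail. `scheduler` stands for any TaskScheduler, the
// per-iteration work of generating, building, and running measure candidates is elided, and
// treating -1 as "no task left" is an assumption of this sketch.
//
//   for (int task_id; (task_id = scheduler->NextTaskId()) != -1;) {
//     TaskRecord record = scheduler->tasks_[task_id];
//     if (record->runner_futures.defined()) {
//       scheduler->JoinRunningTask(task_id);   // block on the in-flight measurements
//     }
//     // ... send the next batch of measure candidates of this task to the builder/runner ...
//     scheduler->TouchTask(task_id);           // refresh the task's is_terminated status
//   }
//   for (int task_id = 0; task_id < static_cast<int>(scheduler->tasks_.size()); ++task_id) {
//     scheduler->TerminateTask(task_id);       // join any leftovers and mark the task finished
//   }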

class TaskScheduler;

/*! \brief The task scheduler with customized methods on the Python side. */
class PyTaskSchedulerNode : public TaskSchedulerNode {
 public:
  /*!
   * \brief The function type of the `NextTaskId` method.
   * \return The next task id.
   */
  using FNextTaskId = runtime::TypedPackedFunc<int()>;
  /*!
   * \brief The function type of the `JoinRunningTask` method.
   * \param task_id The task id to be joined.
   * \return The results from the runner.
   */
  using FJoinRunningTask = runtime::TypedPackedFunc<Array<RunnerResult>(int)>;
  /*! \brief The function type of the `Tune` method. */
  using FTune = runtime::TypedPackedFunc<void(Array<TuneContext> tasks,                  //
                                              Array<FloatImm> task_weights,              //
                                              int max_trials_global,                     //
                                              int max_trials_per_task,                   //
                                              int num_trials_per_iter,                   //
                                              Builder builder,                           //
                                              Runner runner,                             //
                                              Array<MeasureCallback> measure_callbacks,  //
                                              Optional<Database> database,               //
                                              Optional<CostModel> cost_model)>;

  /*! \brief The packed function implementing `NextTaskId`. */
  FNextTaskId f_next_task_id;
  /*! \brief The packed function implementing `JoinRunningTask`. */
  FJoinRunningTask f_join_running_task;
  /*! \brief The packed function implementing `Tune`. */
  FTune f_tune;

  void VisitAttrs(tvm::AttrVisitor* v) {
    TaskSchedulerNode::VisitAttrs(v);
    // `f_next_task_id` is not visited
    // `f_join_running_task` is not visited
    // `f_tune` is not visited
  }

  int NextTaskId() final;
  Array<RunnerResult> JoinRunningTask(int task_id) final;
  void Tune(Array<TuneContext> tasks, Array<FloatImm> task_weights, int max_trials_global,
            int max_trials_per_task, int num_trials_per_iter, Builder builder, Runner runner,
            Array<MeasureCallback> measure_callbacks, Optional<Database> database,
            Optional<CostModel> cost_model) final;

  static constexpr const char* _type_key = "meta_schedule.PyTaskScheduler";
  TVM_DECLARE_FINAL_OBJECT_INFO(PyTaskSchedulerNode, TaskSchedulerNode);
};
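
// Minimal sketch of wiring the packed functions from C++, assuming they are supplied as
// TypedPackedFunc lambdas; in practice they usually come from the Python side via the FFI,
// and `TaskScheduler::PyTaskScheduler` below performs the actual construction. The lambda
// bodies here are placeholders.
//
//   PyTaskSchedulerNode::FNextTaskId f_next_task_id([]() -> int {
//     return -1;  // placeholder: report that no task is left
//   });
//   PyTaskSchedulerNode::FJoinRunningTask f_join_running_task([](int task_id) {
//     return Array<RunnerResult>();  // placeholder: no results
//   });
//   // A PyTaskScheduler built from such functions forwards NextTaskId/JoinRunningTask/Tune
//   // to the corresponding packed functions instead of C++ virtual overrides.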

/*!
 * \brief Managed reference to TaskSchedulerNode.
 * \sa TaskSchedulerNode
 */
class TaskScheduler : public runtime::ObjectRef {
 public:
  /*!
   * \brief Create a task scheduler that fetches tasks in a round-robin fashion.
   * \param logger The tuning task's logging function.
   * \return The task scheduler created.
   */
  TVM_DLL static TaskScheduler RoundRobin(PackedFunc logger);
  /*!
   * \brief Create a task scheduler that fetches tasks in a gradient-based fashion.
   * \param logger The tuning task's logging function.
   * \param alpha The parameter alpha that controls the gradient computation.
   * \param window_size The parameter that controls the backward window size.
   * \param seed The random seed.
   * \return The task scheduler created.
   */
  TVM_DLL static TaskScheduler GradientBased(PackedFunc logger, double alpha, int window_size,
                                             support::LinearCongruentialEngine::TRandState seed);
  /*!
   * \brief Create a task scheduler with customized methods on the Python side.
   * \param logger The tuning task's logging function.
   * \param f_next_task_id The packed function of `NextTaskId`.
   * \param f_join_running_task The packed function of `JoinRunningTask`.
   * \param f_tune The packed function of `Tune`.
   * \return The task scheduler created.
   */
  TVM_DLL static TaskScheduler PyTaskScheduler(
      PackedFunc logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id,
      PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, PyTaskSchedulerNode::FTune f_tune);
  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskScheduler, ObjectRef, TaskSchedulerNode);
};
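
// Usage sketch, assuming `tasks`, `task_weights`, `builder`, `runner`, `callbacks`, `database`,
// `cost_model`, and `logger` were created elsewhere; the trial counts and gradient-based
// hyperparameters below are illustrative values, and in practice this loop is normally driven
// from the Python API rather than hand-written in C++.
//
//   TaskScheduler scheduler = TaskScheduler::GradientBased(logger, /*alpha=*/0.2,
//                                                          /*window_size=*/3, /*seed=*/42);
//   scheduler->Tune(tasks, task_weights,
//                   /*max_trials_global=*/2000, /*max_trials_per_task=*/1000,
//                   /*num_trials_per_iter=*/64, builder, runner, callbacks, database, cost_model);
//   scheduler->PrintTuningStatistics();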

}  // namespace meta_schedule
}  // namespace tvm

#endif  // TVM_META_SCHEDULE_TASK_SCHEDULER_H_