1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24TaskRecord::TaskRecord(TuneContext ctx, double task_weight) {
25 ObjectPtr<TaskRecordNode> n = runtime::make_object<TaskRecordNode>();
26 n->ctx = ctx;
27 n->task_weight = task_weight;
28 n->flop = 1.0;
29 auto _ = Profiler::TimedScope("InitializeTask");
30 CHECK(ctx->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined";
31 CHECK(ctx->space_generator.defined())
32 << "ValueError: Require `context.space_generator`, but it is not defined";
33 CHECK(ctx->search_strategy.defined())
34 << "ValueError: Require `context.search_strategy`, but it is not defined";
35 TVM_PY_LOG(INFO, ctx->logger) << "\n" << ctx->mod;
36 ctx->Initialize();
37 n->flop = std::max(1.0, tir::EstimateTIRFlops(ctx->mod.value()));
38 this->data_ = std::move(n);
39}
40
41void SendToBuilder(TaskRecordNode* self, const Builder& builder) {
42 auto _ = Profiler::TimedScope("SendToBuilder");
43 Array<MeasureCandidate> candidates = self->measure_candidates.value();
44 Target target = self->ctx->target.value();
45 Array<BuilderInput> inputs;
46 inputs.reserve(candidates.size());
47 for (const MeasureCandidate& candidate : candidates) {
48 inputs.push_back(BuilderInput(candidate->sch->mod(), target));
49 }
50 self->builder_results = builder->Build(inputs);
51}
52
/*!
 * \brief Dispatch the successfully-built candidates to the runner, and store
 *        one RunnerFuture per candidate on the task record.
 *
 * Candidates whose build failed are skipped when submitting to the runner;
 * afterwards the resulting futures are merged back so that
 * `self->runner_futures` is index-aligned with `self->measure_candidates`:
 * failed builds get an already-done future carrying the build error message.
 *
 * \param self The task record. `measure_candidates`, `builder_results` and
 *        `ctx->target` must all be defined.
 * \param runner The runner used to benchmark the built artifacts.
 */
void SendToRunner(TaskRecordNode* self, const Runner& runner) {
  auto _ = Profiler::TimedScope("SendToRunner");
  Array<MeasureCandidate> candidates = self->measure_candidates.value();
  Array<BuilderResult> builder_results = self->builder_results.value();
  Target target = self->ctx->target.value();
  ICHECK_EQ(candidates.size(), builder_results.size());
  int n = candidates.size();
  int n_build_errors = 0;
  Array<RunnerInput> inputs;
  inputs.reserve(n);
  for (int i = 0; i < n; ++i) {
    const MeasureCandidate& candidate = candidates[i];
    const BuilderResult& builder_result = builder_results[i];
    // Build failures produce no artifact; exclude them from the runner batch.
    if (builder_result->error_msg.defined()) {
      ++n_build_errors;
      continue;
    }
    inputs.push_back(RunnerInput(/*artifact_path=*/builder_result->artifact_path.value(),
                                 /*device_type=*/target->kind->name,
                                 /*args_info=*/candidate->args_info));
  }
  Array<RunnerFuture> futures = runner->Run(inputs);
  // Fast path: every candidate built, so the futures already line up 1:1.
  if (n_build_errors == 0) {
    self->runner_futures = futures;
    return;
  }
  // Slow path: interleave real futures with synthetic "already failed" futures
  // so that index i of runner_futures corresponds to candidate i. `j` walks
  // the compacted `futures` array of successfully-built candidates only.
  Array<RunnerFuture> results;
  results.reserve(n);
  for (int i = 0, j = 0; i < n; ++i) {
    const BuilderResult& builder_result = builder_results[i];
    if (builder_result->error_msg.defined()) {
      // Synthetic future: immediately done, result is the build error.
      results.push_back(RunnerFuture(
          /*f_done=*/[]() -> bool { return true; },
          /*f_result=*/
          [msg = builder_result->error_msg]() -> RunnerResult {
            return RunnerResult(NullOpt, msg);
          }));
    } else {
      results.push_back(futures[j++]);
    }
  }
  self->runner_futures = results;
}
96
/*!
 * \brief Record the measurement outcome of one tuning round, log per-trial
 *        results, and clear the in-flight state on the task record.
 * \param self The task record being finalized for this round.
 * \param task_id The task's index, used only for log messages.
 * \param results One RunnerResult per measure candidate of this round.
 */
void TaskCleanUp(TaskRecordNode* self, int task_id, const Array<RunnerResult>& results) {
  ICHECK_EQ(self->builder_results.value().size(), results.size());
  ICHECK_EQ(self->runner_futures.value().size(), results.size());
  int n = results.size();
  std::string name = self->ctx->task_name.value();
  const PackedFunc& logger = self->ctx->logger;
  for (int i = 0; i < n; ++i) {
    const BuilderResult& builder_result = self->builder_results.value()[i];
    const MeasureCandidate& candidate = self->measure_candidates.value()[i];
    const RunnerResult& runner_result = results[i];
    Optional<String> error_msg = NullOpt;
    // 1-based ordinal of this trial, counting the entry pushed below.
    int trials = self->latency_ms.size() + 1;
    // 1e9 is the sentinel latency for failed trials (treated as "no result"
    // by PrintTuningStatistics).
    double run_ms = 1e9;
    // NOTE: assignment-in-condition is intentional here — it both records
    // which error (if any) occurred and selects the matching counter.
    if ((error_msg = builder_result->error_msg)) {
      ++self->build_error_count;
    } else if ((error_msg = runner_result->error_msg)) {
      ++self->run_error_count;
    } else {
      run_ms = GetRunMsMedian(runner_result);
    }
    self->latency_ms.push_back(run_ms);
    if (error_msg) {
      // Log the failure together with the schedule and its trace for replay.
      const tir::Schedule& sch = candidate->sch;
      std::string err = error_msg.value();
      TVM_PY_LOG(INFO, logger) << std::fixed << std::setprecision(4)  //
                               << "[Task #" << task_id << ": " << name << "] Trial #" << trials
                               << ": Error in "
                               << (builder_result->error_msg.defined() ? "building" : "running")
                               << ":\n"
                               << err << "\n"
                               << sch->mod() << "\n"
                               << Concat(sch->trace().value()->AsPython(false), "\n");
    } else {
      // Successful trial: report this run's speed and the task's best-so-far.
      double best_ms = *std::min_element(self->latency_ms.begin(), self->latency_ms.end());
      TVM_PY_LOG(INFO, logger) << std::fixed << std::setprecision(4)  //
                               << "[Task #" << task_id << ": " << name << "] Trial #" << trials
                               << ": GFLOPs: " << (self->flop / run_ms / 1e6)
                               << ". Time: " << (run_ms * 1e3) << " us"
                               << ". Best GFLOPs: " << (self->flop / best_ms / 1e6);
    }
  }
  // Drop the round's in-flight state; the task is ready for the next batch.
  self->measure_candidates = NullOpt;
  self->builder_results = NullOpt;
  self->runner_futures = NullOpt;
}
142
143void TaskSchedulerNode::Tune(Array<TuneContext> ctxs, Array<FloatImm> task_weights,
144 int max_trials_global, int max_trials_per_task,
145 int num_trials_per_iter, Builder builder, Runner runner,
146 Array<MeasureCallback> measure_callbacks, Optional<Database> database,
147 Optional<CostModel> cost_model) {
148 CHECK_EQ(ctxs.size(), task_weights.size()) << "ValueError: `task_weights` must have the same "
149 "length as `ctxs`";
150 int n_tasks = this->remaining_tasks_ = ctxs.size();
151 this->measure_callbacks_ = measure_callbacks;
152 this->database_ = database;
153 this->cost_model_ = cost_model;
154 this->tasks_.clear();
155 this->tasks_.reserve(n_tasks);
156 for (int i = 0; i < n_tasks; ++i) {
157 const TuneContext& ctx = ctxs[i];
158 double weight = task_weights[i]->value;
159 TVM_PY_LOG(INFO, this->logger) << "Initializing Task #" << i << ": " << ctx->task_name;
160 TVM_PY_LOG(INFO, ctx->logger) << "Initializing Task #" << i << ": " << ctx->task_name;
161 this->tasks_.push_back(TaskRecord(ctx, weight));
162 Array<tir::Schedule> design_spaces =
163 ctx->space_generator.value()->GenerateDesignSpace(ctx->mod.value());
164 TVM_PY_LOG(INFO, ctx->logger) << "Total " << design_spaces.size()
165 << " design space(s) generated";
166 for (int i = 0, n = design_spaces.size(); i < n; ++i) {
167 tir::Schedule sch = design_spaces[i];
168 tir::Trace trace = sch->trace().value();
169 trace = trace->Simplified(true);
170 TVM_PY_LOG(INFO, ctx->logger) << "Design space #" << i << ":\n"
171 << sch->mod() << "\n"
172 << Concat(trace->AsPython(false), "\n");
173 }
174 ctx->search_strategy.value()->PreTuning(max_trials_per_task, num_trials_per_iter, design_spaces,
175 database, cost_model);
176 }
177
178 int num_trials_already = 0;
179 for (int task_id; num_trials_already < max_trials_global && (task_id = NextTaskId()) != -1;) {
180 TVM_PY_LOG(INFO, this->logger)
181 << "TaskScheduler picks Task #" << task_id << ": " << tasks_[task_id]->ctx->task_name;
182 TaskRecordNode* task = tasks_[task_id].get();
183 ICHECK(!task->is_terminated);
184 ICHECK(!task->runner_futures.defined());
185 if (static_cast<int>(task->latency_ms.size()) >= max_trials_per_task) {
186 TerminateTask(task_id);
187 continue;
188 }
189 if (Optional<Array<MeasureCandidate>> candidates = task->measure_candidates =
190 task->ctx->search_strategy.value()->GenerateMeasureCandidates()) {
191 int num_candidates = candidates.value().size();
192 num_trials_already += num_candidates;
193 TVM_PY_LOG(INFO, this->logger) << "Sending " << num_candidates << " sample(s) to builder";
194 SendToBuilder(task, builder);
195 TVM_PY_LOG(INFO, this->logger) << "Sending " << num_candidates << " sample(s) to runner";
196 SendToRunner(task, runner);
197 } else {
198 TerminateTask(task_id);
199 }
200 }
201 for (int task_id = 0; task_id < n_tasks; ++task_id) {
202 TaskRecordNode* task = this->tasks_[task_id].get();
203 if (!task->is_terminated) {
204 if (task->runner_futures.defined()) {
205 JoinRunningTask(task_id);
206 }
207 TerminateTask(task_id);
208 }
209 task->ctx->search_strategy.value()->PostTuning();
210 }
211}
212
213Array<RunnerResult> TaskSchedulerNode::JoinRunningTask(int task_id) {
214 TaskRecordNode* task = this->tasks_[task_id].get();
215 ICHECK(task->runner_futures.defined());
216 Array<RunnerResult> results;
217 {
218 auto _ = Profiler::TimedScope("JoinRunnerFutures");
219 Array<RunnerFuture> futures = task->runner_futures.value();
220 results.reserve(futures.size());
221 for (RunnerFuture future : futures) {
222 results.push_back(future->Result());
223 }
224 }
225 ICHECK(task->measure_candidates.defined());
226 task->ctx->search_strategy.value()->NotifyRunnerResults(task->measure_candidates.value(),
227 results);
228 ICHECK(task->builder_results.defined());
229 ICHECK_EQ(results.size(), task->measure_candidates.value().size());
230 ICHECK_EQ(results.size(), task->builder_results.value().size());
231 for (const MeasureCallback& callback : this->measure_callbacks_) {
232 callback->Apply(GetRef<TaskScheduler>(this), task_id, task->measure_candidates.value(),
233 task->builder_results.value(), results);
234 }
235 TaskCleanUp(task, task_id, results);
236 TVM_PY_LOG_CLEAR_SCREEN(this->logger);
237 TVM_PY_LOG(INFO, this->logger) << "[Updated] Task #" << task_id << ": " << task->ctx->task_name;
238 this->PrintTuningStatistics();
239 return results;
240}
241
242void TaskSchedulerNode::TouchTask(int task_id) {
243 TaskRecordNode* task = this->tasks_[task_id].get();
244 if (!task->is_terminated && task->runner_futures.defined()) {
245 for (const RunnerFuture future : task->runner_futures.value()) {
246 if (!future->Done()) {
247 return;
248 }
249 }
250 this->JoinRunningTask(task_id);
251 }
252}
253
254void TaskSchedulerNode::TerminateTask(int task_id) {
255 TaskRecordNode* task = this->tasks_[task_id].get();
256 ICHECK(!task->is_terminated);
257 task->is_terminated = true;
258 --this->remaining_tasks_;
259 TVM_PY_LOG_CLEAR_SCREEN(this->logger);
260 TVM_PY_LOG(INFO, this->logger) << "Task #" << task_id
261 << " has finished. Remaining task(s): " << this->remaining_tasks_;
262 this->PrintTuningStatistics();
263}
264
265void TaskSchedulerNode::PrintTuningStatistics() {
266 std::ostringstream os;
267 int n_tasks = this->tasks_.size();
268 int total_trials = 0;
269 double total_latency = 0.0;
270 support::TablePrinter p;
271 p.Row() << "ID"
272 << "Name"
273 << "FLOP"
274 << "Weight"
275 << "Speed (GFLOPS)"
276 << "Latency (us)"
277 << "Weighted Latency (us)"
278 << "Trials"
279 << "Done";
280 p.Separator();
281 for (int i = 0; i < n_tasks; ++i) {
282 const TaskRecordNode* task = this->tasks_[i].get();
283 auto row = p.Row();
284 int trials = task->latency_ms.size();
285 row << /*id=*/i << /*name=*/task->ctx->task_name.value() //
286 << /*flops=*/static_cast<int64_t>(task->flop)
287 << /*weight=*/static_cast<int>(task->task_weight);
288 double latency_ms = 1e9;
289 if (!task->latency_ms.empty()) {
290 latency_ms = *std::min_element(task->latency_ms.begin(), task->latency_ms.end());
291 }
292 if (latency_ms >= 1e9) {
293 row << /*speed=*/"N/A" << /*latency=*/"N/A" << /*weighted_latency=*/"N/A";
294 } else {
295 latency_ms *= 1000.0;
296 double speed = task->flop / latency_ms / 1000.0;
297 double weighted_latency = latency_ms * task->task_weight;
298 row << /*speed=*/speed << /*latency=*/latency_ms << /*weighted_latency=*/weighted_latency;
299 total_latency += weighted_latency;
300 total_trials += trials;
301 }
302 row << trials;
303 if (task->is_terminated) {
304 row << "Y";
305 } else {
306 row << "";
307 }
308 }
309 p.Separator();
310
311 os << "\nTotal trials: " << total_trials //
312 << "\nTotal latency (us): " << total_latency //
313 << "\n";
314
315 if (using_ipython()) {
316 print_interactive_table(p.AsStr());
317 std::cout << os.str() << std::endl << std::flush;
318 TVM_PY_LOG(DEBUG, this->logger) << "\n" << p.AsStr() << os.str();
319 } else {
320 TVM_PY_LOG(INFO, this->logger) << "\n" << p.AsStr() << os.str();
321 }
322}
323
324TaskScheduler TaskScheduler::PyTaskScheduler(
325 PackedFunc logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id,
326 PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, PyTaskSchedulerNode::FTune f_tune) {
327 CHECK(f_next_task_id != nullptr) << "ValueError: next_task_id is not defined";
328 ObjectPtr<PyTaskSchedulerNode> n = make_object<PyTaskSchedulerNode>();
329 n->logger = logger;
330 n->f_next_task_id = f_next_task_id;
331 n->f_join_running_task = f_join_running_task;
332 n->f_tune = f_tune;
333 return TaskScheduler(n);
334}
335
/*!
 * \brief Pick the next task to tune by delegating to the Python-side callback.
 * \return The chosen task index, or -1 when no task remains (by convention of
 *         the caller `TaskSchedulerNode::Tune`).
 */
int PyTaskSchedulerNode::NextTaskId() {
  // Unlike f_join_running_task / f_tune, this callback has no C++ fallback.
  CHECK(f_next_task_id != nullptr) << "PyTaskScheduler's NextTaskId method not implemented!";
  return f_next_task_id();
}
340
341Array<RunnerResult> PyTaskSchedulerNode::JoinRunningTask(int task_id) {
342 if (f_join_running_task == nullptr) {
343 return TaskSchedulerNode::JoinRunningTask(task_id);
344 } else {
345 return f_join_running_task(task_id);
346 }
347}
348
349void PyTaskSchedulerNode::Tune(Array<TuneContext> tasks, Array<FloatImm> task_weights,
350 int max_trials_global, int max_trials_per_task,
351 int num_trials_per_iter, Builder builder, Runner runner,
352 Array<MeasureCallback> measure_callbacks,
353 Optional<Database> database, Optional<CostModel> cost_model) {
354 if (f_tune == nullptr) {
355 TaskSchedulerNode::Tune(tasks, task_weights, max_trials_global, max_trials_per_task,
356 num_trials_per_iter, builder, runner, measure_callbacks, database,
357 cost_model);
358 } else {
359 f_tune(tasks, task_weights, max_trials_global, max_trials_per_task, num_trials_per_iter,
360 builder, runner, measure_callbacks, database, cost_model);
361 }
362}
363
// Object-system registrations for the node types defined in this file.
TVM_REGISTER_NODE_TYPE(TaskRecordNode);
TVM_REGISTER_OBJECT_TYPE(TaskSchedulerNode);
TVM_REGISTER_NODE_TYPE(PyTaskSchedulerNode);
// FFI entry points exposing the task-scheduler API to the Python frontend.
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPyTaskScheduler")
    .set_body_typed(TaskScheduler::PyTaskScheduler);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTune")
    .set_body_method<TaskScheduler>(&TaskSchedulerNode::Tune);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerJoinRunningTask")
    .set_body_method<TaskScheduler>(&TaskSchedulerNode::JoinRunningTask);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerNextTaskId")
    .set_body_method<TaskScheduler>(&TaskSchedulerNode::NextTaskId);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTerminateTask")
    .set_body_method<TaskScheduler>(&TaskSchedulerNode::TerminateTask);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTouchTask")
    .set_body_method<TaskScheduler>(&TaskSchedulerNode::TouchTask);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPrintTuningStatistics")
    .set_body_method<TaskScheduler>(&TaskSchedulerNode::PrintTuningStatistics);
381
382} // namespace meta_schedule
383} // namespace tvm
384