/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

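// A TaskRecord bundles one TuneContext with its scheduling state. The
// constructor validates that the context provides a module, a space generator
// and a search strategy, initializes the context, and estimates the task's
// FLOP count (clamped to at least 1) for later throughput reporting.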
TaskRecord::TaskRecord(TuneContext ctx, double task_weight) {
  ObjectPtr<TaskRecordNode> n = runtime::make_object<TaskRecordNode>();
  n->ctx = ctx;
  n->task_weight = task_weight;
  n->flop = 1.0;
  auto _ = Profiler::TimedScope("InitializeTask");
  CHECK(ctx->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined";
  CHECK(ctx->space_generator.defined())
      << "ValueError: Require `context.space_generator`, but it is not defined";
  CHECK(ctx->search_strategy.defined())
      << "ValueError: Require `context.search_strategy`, but it is not defined";
  TVM_PY_LOG(INFO, ctx->logger) << "\n" << ctx->mod;
  ctx->Initialize();
  n->flop = std::max(1.0, tir::EstimateTIRFlops(ctx->mod.value()));
  this->data_ = std::move(n);
}

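// Wrap each pending measure candidate's module into a BuilderInput and
// dispatch the batch to the builder, caching the results on the record.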
void SendToBuilder(TaskRecordNode* self, const Builder& builder) {
  auto _ = Profiler::TimedScope("SendToBuilder");
  Array<MeasureCandidate> candidates = self->measure_candidates.value();
  Target target = self->ctx->target.value();
  Array<BuilderInput> inputs;
  inputs.reserve(candidates.size());
  for (const MeasureCandidate& candidate : candidates) {
    inputs.push_back(BuilderInput(candidate->sch->mod(), target));
  }
  self->builder_results = builder->Build(inputs);
}

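// Send the successfully built candidates to the runner. Candidates whose
// build failed are skipped, and their slots are back-filled with
// already-resolved futures carrying the build error so that `runner_futures`
// stays index-aligned with `measure_candidates`.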
void SendToRunner(TaskRecordNode* self, const Runner& runner) {
  auto _ = Profiler::TimedScope("SendToRunner");
  Array<MeasureCandidate> candidates = self->measure_candidates.value();
  Array<BuilderResult> builder_results = self->builder_results.value();
  Target target = self->ctx->target.value();
  ICHECK_EQ(candidates.size(), builder_results.size());
  int n = candidates.size();
  int n_build_errors = 0;
  Array<RunnerInput> inputs;
  inputs.reserve(n);
  for (int i = 0; i < n; ++i) {
    const MeasureCandidate& candidate = candidates[i];
    const BuilderResult& builder_result = builder_results[i];
    if (builder_result->error_msg.defined()) {
      ++n_build_errors;
      continue;
    }
    inputs.push_back(RunnerInput(/*artifact_path=*/builder_result->artifact_path.value(),
                                 /*device_type=*/target->kind->name,
                                 /*args_info=*/candidate->args_info));
  }
  Array<RunnerFuture> futures = runner->Run(inputs);
  if (n_build_errors == 0) {
    self->runner_futures = futures;
    return;
  }
  Array<RunnerFuture> results;
  results.reserve(n);
  for (int i = 0, j = 0; i < n; ++i) {
    const BuilderResult& builder_result = builder_results[i];
    if (builder_result->error_msg.defined()) {
      results.push_back(RunnerFuture(
          /*f_done=*/[]() -> bool { return true; },
          /*f_result=*/
          [msg = builder_result->error_msg]() -> RunnerResult {
            return RunnerResult(NullOpt, msg);
          }));
    } else {
      results.push_back(futures[j++]);
    }
  }
  self->runner_futures = results;
}

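// Fold one batch of runner results back into the task's statistics: bump the
// build/run error counters, append per-trial latency (1e9 ms serves as the
// sentinel for failed trials), log each trial, and clear the batch state.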
void TaskCleanUp(TaskRecordNode* self, int task_id, const Array<RunnerResult>& results) {
  ICHECK_EQ(self->builder_results.value().size(), results.size());
  ICHECK_EQ(self->runner_futures.value().size(), results.size());
  int n = results.size();
  std::string name = self->ctx->task_name.value();
  const PackedFunc& logger = self->ctx->logger;
  for (int i = 0; i < n; ++i) {
    const BuilderResult& builder_result = self->builder_results.value()[i];
    const MeasureCandidate& candidate = self->measure_candidates.value()[i];
    const RunnerResult& runner_result = results[i];
    Optional<String> error_msg = NullOpt;
    int trials = self->latency_ms.size() + 1;
    double run_ms = 1e9;
    if ((error_msg = builder_result->error_msg)) {
      ++self->build_error_count;
    } else if ((error_msg = runner_result->error_msg)) {
      ++self->run_error_count;
    } else {
      run_ms = GetRunMsMedian(runner_result);
    }
    self->latency_ms.push_back(run_ms);
    if (error_msg) {
      const tir::Schedule& sch = candidate->sch;
      std::string err = error_msg.value();
      TVM_PY_LOG(INFO, logger) << std::fixed << std::setprecision(4)  //
                               << "[Task #" << task_id << ": " << name << "] Trial #" << trials
                               << ": Error in "
                               << (builder_result->error_msg.defined() ? "building" : "running")
                               << ":\n"
                               << err << "\n"
                               << sch->mod() << "\n"
                               << Concat(sch->trace().value()->AsPython(false), "\n");
    } else {
      double best_ms = *std::min_element(self->latency_ms.begin(), self->latency_ms.end());
      TVM_PY_LOG(INFO, logger) << std::fixed << std::setprecision(4)  //
                               << "[Task #" << task_id << ": " << name << "] Trial #" << trials
                               << ": GFLOPs: " << (self->flop / run_ms / 1e6)
                               << ". Time: " << (run_ms * 1e3) << " us"
                               << ". Best GFLOPs: " << (self->flop / best_ms / 1e6);
    }
  }
  self->measure_candidates = NullOpt;
  self->builder_results = NullOpt;
  self->runner_futures = NullOpt;
}

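// Default tuning driver: initialize every task and its design spaces, then
// repeatedly pick a task via `NextTaskId()`, generate measure candidates for
// it, and pipeline them through the builder and runner until the trial
// budgets are exhausted.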
void TaskSchedulerNode::Tune(Array<TuneContext> ctxs, Array<FloatImm> task_weights,
                             int max_trials_global, int max_trials_per_task,
                             int num_trials_per_iter, Builder builder, Runner runner,
                             Array<MeasureCallback> measure_callbacks, Optional<Database> database,
                             Optional<CostModel> cost_model) {
  CHECK_EQ(ctxs.size(), task_weights.size()) << "ValueError: `task_weights` must have the same "
                                                "length as `ctxs`";
  int n_tasks = this->remaining_tasks_ = ctxs.size();
  this->measure_callbacks_ = measure_callbacks;
  this->database_ = database;
  this->cost_model_ = cost_model;
  this->tasks_.clear();
  this->tasks_.reserve(n_tasks);
  for (int i = 0; i < n_tasks; ++i) {
    const TuneContext& ctx = ctxs[i];
    double weight = task_weights[i]->value;
    TVM_PY_LOG(INFO, this->logger) << "Initializing Task #" << i << ": " << ctx->task_name;
    TVM_PY_LOG(INFO, ctx->logger) << "Initializing Task #" << i << ": " << ctx->task_name;
    this->tasks_.push_back(TaskRecord(ctx, weight));
    Array<tir::Schedule> design_spaces =
        ctx->space_generator.value()->GenerateDesignSpace(ctx->mod.value());
    TVM_PY_LOG(INFO, ctx->logger) << "Total " << design_spaces.size()
                                  << " design space(s) generated";
    for (int j = 0, n = design_spaces.size(); j < n; ++j) {
      tir::Schedule sch = design_spaces[j];
      tir::Trace trace = sch->trace().value();
      trace = trace->Simplified(true);
      TVM_PY_LOG(INFO, ctx->logger) << "Design space #" << j << ":\n"
                                    << sch->mod() << "\n"
                                    << Concat(trace->AsPython(false), "\n");
    }
    ctx->search_strategy.value()->PreTuning(max_trials_per_task, num_trials_per_iter, design_spaces,
                                            database, cost_model);
  }

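  // Main scheduling loop: keep measuring until the global trial budget is
  // exhausted or no task can produce more candidates.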
  int num_trials_already = 0;
  for (int task_id; num_trials_already < max_trials_global && (task_id = NextTaskId()) != -1;) {
    TVM_PY_LOG(INFO, this->logger)
        << "TaskScheduler picks Task #" << task_id << ": " << tasks_[task_id]->ctx->task_name;
    TaskRecordNode* task = tasks_[task_id].get();
    ICHECK(!task->is_terminated);
    ICHECK(!task->runner_futures.defined());
    if (static_cast<int>(task->latency_ms.size()) >= max_trials_per_task) {
      TerminateTask(task_id);
      continue;
    }
    if (Optional<Array<MeasureCandidate>> candidates = task->measure_candidates =
            task->ctx->search_strategy.value()->GenerateMeasureCandidates()) {
      int num_candidates = candidates.value().size();
      num_trials_already += num_candidates;
      TVM_PY_LOG(INFO, this->logger) << "Sending " << num_candidates << " sample(s) to builder";
      SendToBuilder(task, builder);
      TVM_PY_LOG(INFO, this->logger) << "Sending " << num_candidates << " sample(s) to runner";
      SendToRunner(task, runner);
    } else {
      TerminateTask(task_id);
    }
  }
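  // All budgets spent or tasks exhausted: drain in-flight measurements and
  // notify every search strategy that tuning is done.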
  for (int task_id = 0; task_id < n_tasks; ++task_id) {
    TaskRecordNode* task = this->tasks_[task_id].get();
    if (!task->is_terminated) {
      if (task->runner_futures.defined()) {
        JoinRunningTask(task_id);
      }
      TerminateTask(task_id);
    }
    task->ctx->search_strategy.value()->PostTuning();
  }
}

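// Block until every in-flight future of the task resolves, propagate the
// results to the search strategy and the measure callbacks, then clean up the
// batch state and report updated statistics.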
Array<RunnerResult> TaskSchedulerNode::JoinRunningTask(int task_id) {
  TaskRecordNode* task = this->tasks_[task_id].get();
  ICHECK(task->runner_futures.defined());
  Array<RunnerResult> results;
  {
    auto _ = Profiler::TimedScope("JoinRunnerFutures");
    Array<RunnerFuture> futures = task->runner_futures.value();
    results.reserve(futures.size());
    for (RunnerFuture future : futures) {
      results.push_back(future->Result());
    }
  }
  ICHECK(task->measure_candidates.defined());
  task->ctx->search_strategy.value()->NotifyRunnerResults(task->measure_candidates.value(),
                                                          results);
  ICHECK(task->builder_results.defined());
  ICHECK_EQ(results.size(), task->measure_candidates.value().size());
  ICHECK_EQ(results.size(), task->builder_results.value().size());
  for (const MeasureCallback& callback : this->measure_callbacks_) {
    callback->Apply(GetRef<TaskScheduler>(this), task_id, task->measure_candidates.value(),
                    task->builder_results.value(), results);
  }
  TaskCleanUp(task, task_id, results);
  TVM_PY_LOG_CLEAR_SCREEN(this->logger);
  TVM_PY_LOG(INFO, this->logger) << "[Updated] Task #" << task_id << ": " << task->ctx->task_name;
  this->PrintTuningStatistics();
  return results;
}

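// Non-blocking poll: if all of the task's in-flight measurements have
// completed, join the task so its results get processed.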
void TaskSchedulerNode::TouchTask(int task_id) {
  TaskRecordNode* task = this->tasks_[task_id].get();
  if (!task->is_terminated && task->runner_futures.defined()) {
    for (const RunnerFuture& future : task->runner_futures.value()) {
      if (!future->Done()) {
        return;
      }
    }
    this->JoinRunningTask(task_id);
  }
}

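// Mark a task as terminated, update the remaining-task counter, and log the
// latest tuning statistics.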
void TaskSchedulerNode::TerminateTask(int task_id) {
  TaskRecordNode* task = this->tasks_[task_id].get();
  ICHECK(!task->is_terminated);
  task->is_terminated = true;
  --this->remaining_tasks_;
  TVM_PY_LOG_CLEAR_SCREEN(this->logger);
  TVM_PY_LOG(INFO, this->logger) << "Task #" << task_id
                                 << " has finished. Remaining task(s): " << this->remaining_tasks_;
  this->PrintTuningStatistics();
}

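// Render a per-task summary table (FLOPs, best speed, latency, trial counts)
// plus aggregate totals. Latencies are stored in milliseconds and converted
// to microseconds for display.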
void TaskSchedulerNode::PrintTuningStatistics() {
  std::ostringstream os;
  int n_tasks = this->tasks_.size();
  int total_trials = 0;
  double total_latency = 0.0;
  support::TablePrinter p;
  p.Row() << "ID"
          << "Name"
          << "FLOP"
          << "Weight"
          << "Speed (GFLOPS)"
          << "Latency (us)"
          << "Weighted Latency (us)"
          << "Trials"
          << "Done";
  p.Separator();
  for (int i = 0; i < n_tasks; ++i) {
    const TaskRecordNode* task = this->tasks_[i].get();
    auto row = p.Row();
    int trials = task->latency_ms.size();
    row << /*id=*/i << /*name=*/task->ctx->task_name.value()  //
        << /*flops=*/static_cast<int64_t>(task->flop)
        << /*weight=*/static_cast<int>(task->task_weight);
    double latency_ms = 1e9;
    if (!task->latency_ms.empty()) {
      latency_ms = *std::min_element(task->latency_ms.begin(), task->latency_ms.end());
    }
    if (latency_ms >= 1e9) {
      row << /*speed=*/"N/A" << /*latency=*/"N/A" << /*weighted_latency=*/"N/A";
    } else {
      latency_ms *= 1000.0;
      double speed = task->flop / latency_ms / 1000.0;
      double weighted_latency = latency_ms * task->task_weight;
      row << /*speed=*/speed << /*latency=*/latency_ms << /*weighted_latency=*/weighted_latency;
      total_latency += weighted_latency;
      total_trials += trials;
    }
    row << trials;
    if (task->is_terminated) {
      row << "Y";
    } else {
      row << "";
    }
  }
  p.Separator();

  os << "\nTotal trials: " << total_trials       //
     << "\nTotal latency (us): " << total_latency  //
     << "\n";

  if (using_ipython()) {
    print_interactive_table(p.AsStr());
    std::cout << os.str() << std::endl << std::flush;
    TVM_PY_LOG(DEBUG, this->logger) << "\n" << p.AsStr() << os.str();
  } else {
    TVM_PY_LOG(INFO, this->logger) << "\n" << p.AsStr() << os.str();
  }
}

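// Construct a task scheduler whose hooks are supplied from Python as packed
// functions; only `f_next_task_id` is mandatory.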
TaskScheduler TaskScheduler::PyTaskScheduler(
    PackedFunc logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id,
    PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, PyTaskSchedulerNode::FTune f_tune) {
  CHECK(f_next_task_id != nullptr) << "ValueError: next_task_id is not defined";
  ObjectPtr<PyTaskSchedulerNode> n = make_object<PyTaskSchedulerNode>();
  n->logger = logger;
  n->f_next_task_id = f_next_task_id;
  n->f_join_running_task = f_join_running_task;
  n->f_tune = f_tune;
  return TaskScheduler(n);
}

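// Task selection is always delegated to the Python-side callback.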
int PyTaskSchedulerNode::NextTaskId() {
  CHECK(f_next_task_id != nullptr) << "PyTaskScheduler's NextTaskId method not implemented!";
  return f_next_task_id();
}

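// Prefer the Python-side override when it is given; otherwise fall back to
// the default C++ implementation.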
Array<RunnerResult> PyTaskSchedulerNode::JoinRunningTask(int task_id) {
  if (f_join_running_task == nullptr) {
    return TaskSchedulerNode::JoinRunningTask(task_id);
  } else {
    return f_join_running_task(task_id);
  }
}

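// Likewise, dispatch to the Python-side `tune` override if provided.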
void PyTaskSchedulerNode::Tune(Array<TuneContext> tasks, Array<FloatImm> task_weights,
                               int max_trials_global, int max_trials_per_task,
                               int num_trials_per_iter, Builder builder, Runner runner,
                               Array<MeasureCallback> measure_callbacks,
                               Optional<Database> database, Optional<CostModel> cost_model) {
  if (f_tune == nullptr) {
    TaskSchedulerNode::Tune(tasks, task_weights, max_trials_global, max_trials_per_task,
                            num_trials_per_iter, builder, runner, measure_callbacks, database,
                            cost_model);
  } else {
    f_tune(tasks, task_weights, max_trials_global, max_trials_per_task, num_trials_per_iter,
           builder, runner, measure_callbacks, database, cost_model);
  }
}

TVM_REGISTER_NODE_TYPE(TaskRecordNode);
TVM_REGISTER_OBJECT_TYPE(TaskSchedulerNode);
TVM_REGISTER_NODE_TYPE(PyTaskSchedulerNode);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPyTaskScheduler")
    .set_body_typed(TaskScheduler::PyTaskScheduler);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTune")
    .set_body_method<TaskScheduler>(&TaskSchedulerNode::Tune);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerJoinRunningTask")
    .set_body_method<TaskScheduler>(&TaskSchedulerNode::JoinRunningTask);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerNextTaskId")
    .set_body_method<TaskScheduler>(&TaskSchedulerNode::NextTaskId);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTerminateTask")
    .set_body_method<TaskScheduler>(&TaskSchedulerNode::TerminateTask);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTouchTask")
    .set_body_method<TaskScheduler>(&TaskSchedulerNode::TouchTask);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPrintTuningStatistics")
    .set_body_method<TaskScheduler>(&TaskSchedulerNode::PrintTuningStatistics);

}  // namespace meta_schedule
}  // namespace tvm