1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/measure.cc |
22 | * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs. |
23 | */ |
24 | |
25 | #include <tvm/auto_scheduler/measure.h> |
26 | #include <tvm/runtime/registry.h> |
27 | |
28 | #include <algorithm> |
29 | |
30 | #include "search_policy/empty_policy.h" |
31 | #include "search_policy/sketch_policy.h" |
32 | #include "utils.h" |
33 | |
34 | namespace tvm { |
35 | namespace auto_scheduler { |
36 | |
// Register the measurement-related node types with the TVM object system so
// they can be created, reflected, and serialized through the FFI.
// NODE_TYPE registrations additionally enable reflection of visited fields;
// OBJECT_TYPE registrations only make the type known to the runtime.
TVM_REGISTER_NODE_TYPE(MeasureInputNode);
TVM_REGISTER_NODE_TYPE(BuildResultNode);
TVM_REGISTER_NODE_TYPE(MeasureResultNode);
TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode);
TVM_REGISTER_OBJECT_TYPE(PythonBasedMeasureCallbackNode);
TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode);
TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode);
TVM_REGISTER_OBJECT_TYPE(ProgramMeasurerNode);
TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode);
TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode);
TVM_REGISTER_OBJECT_TYPE(RPCRunnerNode);
48 | |
// Human-readable names for measurement error numbers, used when printing a
// MeasureResult (ErrorNoToStr[error_no] below).
// NOTE(review): the order of these strings must stay in sync with the
// MeasureErrorNO enum — the array is indexed directly by the error number.
static const char* ErrorNoToStr[] = {
    "NoError",
    "InstantiationError",
    "CompileHostError",
    "CompileDeviceError",
    "RuntimeDeviceError",
    "WrongAnswerError",
    "BuildTimeoutError",
    "RunTimeoutError",
    "UnknownError",
};
60 | |
61 | /********** Measure input and result **********/ |
62 | MeasureInput::MeasureInput(SearchTask task, State state) { |
63 | auto node = make_object<MeasureInputNode>(); |
64 | node->task = std::move(task); |
65 | node->state = std::move(state); |
66 | data_ = std::move(node); |
67 | } |
68 | |
69 | MeasureInput MeasureInputNode::copy() const { |
70 | auto node = make_object<MeasureInputNode>(); |
71 | node->task = task; |
72 | node->state = state; |
73 | return MeasureInput(node); |
74 | } |
75 | |
76 | BuildResult::BuildResult(String filename, Array<te::Tensor> args, int error_no, String error_msg, |
77 | double time_cost) { |
78 | auto node = make_object<BuildResultNode>(); |
79 | node->filename = std::move(filename); |
80 | node->args = std::move(args); |
81 | node->error_no = error_no; |
82 | node->error_msg = std::move(error_msg); |
83 | node->time_cost = time_cost; |
84 | data_ = std::move(node); |
85 | } |
86 | |
87 | MeasureResult::MeasureResult(Array<PrimExpr> costs, int error_no, String error_msg, double all_cost, |
88 | double timestamp) { |
89 | auto node = make_object<MeasureResultNode>(); |
90 | node->costs = std::move(costs); |
91 | node->error_no = error_no; |
92 | node->error_msg = std::move(error_msg); |
93 | node->all_cost = all_cost; |
94 | node->timestamp = timestamp; |
95 | data_ = std::move(node); |
96 | } |
97 | |
98 | MeasureResult MeasureResultNode::copy() const { |
99 | auto node = make_object<MeasureResultNode>(); |
100 | node->costs = costs; |
101 | node->error_no = error_no; |
102 | node->error_msg = error_msg; |
103 | node->all_cost = all_cost; |
104 | node->timestamp = timestamp; |
105 | return MeasureResult(node); |
106 | } |
107 | |
108 | /********** LocalBuilder **********/ |
109 | LocalBuilder::LocalBuilder(int timeout, int n_parallel, const String& build_func) { |
110 | auto node = make_object<LocalBuilderNode>(); |
111 | node->timeout = timeout; |
112 | node->n_parallel = n_parallel; |
113 | node->build_func = build_func; |
114 | data_ = std::move(node); |
115 | } |
116 | |
117 | Array<BuildResult> LocalBuilderNode::Build(const Array<MeasureInput>& inputs, int verbose) { |
118 | if (const auto* f = runtime::Registry::Get("auto_scheduler.local_builder.build" )) { |
119 | Array<BuildResult> results = (*f)(inputs, timeout, n_parallel, build_func, verbose); |
120 | return results; |
121 | } |
122 | LOG(FATAL) << "auto_scheduler.local_builder.build is not registered. " |
123 | << "This is a function registered in Python, " |
124 | << "make sure the TVM Python runtime has been loaded successfully." ; |
125 | throw; |
126 | } |
127 | |
128 | /********** LocalRunner **********/ |
129 | LocalRunner::LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, |
130 | double cooldown_interval, bool enable_cpu_cache_flush, int device) { |
131 | ObjectPtr<LocalRunnerNode> node = make_object<LocalRunnerNode>(); |
132 | node->timeout = timeout; |
133 | node->number = number; |
134 | node->repeat = repeat; |
135 | node->min_repeat_ms = min_repeat_ms; |
136 | node->cooldown_interval = cooldown_interval; |
137 | node->enable_cpu_cache_flush = enable_cpu_cache_flush; |
138 | node->device = device; |
139 | data_ = std::move(node); |
140 | } |
141 | |
142 | Array<MeasureResult> LocalRunnerNode::Run(const Array<MeasureInput>& inputs, |
143 | const Array<BuildResult>& build_results, int verbose) { |
144 | if (const auto* f = runtime::Registry::Get("auto_scheduler.local_runner.run" )) { |
145 | Array<MeasureResult> results = |
146 | (*f)(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval, |
147 | enable_cpu_cache_flush, verbose, device); |
148 | return results; |
149 | } |
150 | LOG(FATAL) << "auto_scheduler.local_runner.run is not registered. " |
151 | << "This is a function registered in Python, " |
152 | << "make sure the TVM Python runtime has been loaded successfully." ; |
153 | throw; |
154 | } |
155 | |
156 | /********** RPCRunner **********/ |
157 | RPCRunner::RPCRunner(const String& key, const String& host, int port, int priority, int n_parallel, |
158 | int timeout, int number, int repeat, int min_repeat_ms, |
159 | double cooldown_interval, bool enable_cpu_cache_flush, int device) { |
160 | auto node = make_object<RPCRunnerNode>(); |
161 | node->key = key; |
162 | node->host = host; |
163 | node->port = port; |
164 | node->priority = priority; |
165 | node->timeout = timeout; |
166 | node->n_parallel = n_parallel; |
167 | node->number = number; |
168 | node->repeat = repeat; |
169 | node->min_repeat_ms = min_repeat_ms; |
170 | node->cooldown_interval = cooldown_interval; |
171 | node->enable_cpu_cache_flush = enable_cpu_cache_flush; |
172 | node->device = device; |
173 | data_ = std::move(node); |
174 | } |
175 | |
176 | Array<MeasureResult> RPCRunnerNode::Run(const Array<MeasureInput>& inputs, |
177 | const Array<BuildResult>& build_results, int verbose) { |
178 | if (const auto* f = runtime::Registry::Get("auto_scheduler.rpc_runner.run" )) { |
179 | Array<MeasureResult> results = |
180 | (*f)(inputs, build_results, key, host, port, priority, n_parallel, timeout, number, repeat, |
181 | min_repeat_ms, cooldown_interval, enable_cpu_cache_flush, verbose, device); |
182 | return results; |
183 | } else { |
184 | LOG(FATAL) << "auto_scheduler.rpc_runner.run is not registered. " |
185 | << "This is a function registered in Python, " |
186 | << "make sure the TVM Python runtime has been loaded successfully." ; |
187 | } |
188 | return Array<MeasureResult>(); |
189 | } |
190 | |
191 | /********** MeasureCallback **********/ |
192 | PythonBasedMeasureCallback::PythonBasedMeasureCallback(PackedFunc callback_func) { |
193 | auto node = make_object<PythonBasedMeasureCallbackNode>(); |
194 | node->callback_func = std::move(callback_func); |
195 | data_ = std::move(node); |
196 | } |
197 | |
198 | void PythonBasedMeasureCallbackNode::Callback(const SearchPolicy& policy, |
199 | const Array<MeasureInput>& inputs, |
200 | const Array<MeasureResult>& results) { |
201 | if (auto* sketch_policy = static_cast<SketchPolicyNode*>(policy.operator->())) { |
202 | callback_func(GetRef<SketchPolicy>(sketch_policy), inputs, results); |
203 | } else if (auto* empty_policy = static_cast<EmptyPolicyNode*>(policy.operator->())) { |
204 | callback_func(GetRef<EmptyPolicy>(empty_policy), inputs, results); |
205 | } else { |
206 | LOG(FATAL) << "Unrecognized search policy type. Expect SketchPolicy or EmptyPolicy" ; |
207 | } |
208 | } |
209 | |
210 | /********** ProgramMeasurer **********/ |
211 | ProgramMeasurer::ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, |
212 | Optional<Array<MeasureCallback>> callbacks, int verbose, |
213 | int max_continuous_error) { |
214 | auto node = make_object<ProgramMeasurerNode>(); |
215 | node->builder = std::move(builder); |
216 | node->runner = std::move(runner); |
217 | node->callbacks = std::move(callbacks); |
218 | node->verbose = verbose; |
219 | node->max_continuous_error = max_continuous_error < 0 |
220 | ? ProgramMeasurerNode::DEFAULT_MAX_CONTINUOUS_ERROR |
221 | : max_continuous_error; |
222 | data_ = std::move(node); |
223 | } |
224 | |
225 | void ProgramMeasurerNode::Reset() { |
226 | ct = error_ct = 0; |
227 | best_flops.clear(); |
228 | best_ct.clear(); |
229 | best_state.clear(); |
230 | has_valid.clear(); |
231 | } |
232 | |
/*!
 * \brief Build and run `inputs` in batches, updating per-workload best records.
 *
 * For each batch: build+run via SilentMeasure, update best_flops/best_state/
 * best_ct keyed by workload_key, invoke the registered callbacks, and append
 * the results. After `max_continuous_error` consecutive failed measurements the
 * verbosity is forced to 2 (debug output) until a measurement succeeds again.
 *
 * \param task Used only for flop_ct when converting costs to GFLOPS.
 * \param policy Passed through to the measure callbacks.
 * \param inputs The programs to measure.
 * \param batch_size Programs per batch; -1 selects builder->n_parallel * 2.
 *                   NOTE(review): assumed positive when not -1 — a
 *                   non-positive value would make the loop below not advance.
 * \return One MeasureResult per input, in input order.
 */
Array<MeasureResult> ProgramMeasurerNode::Measure(const SearchTask& task,
                                                  const SearchPolicy& policy,
                                                  const Array<MeasureInput>& inputs,
                                                  int batch_size) {
  auto t_begin = std::chrono::high_resolution_clock::now();

  Array<MeasureResult> results;
  results.reserve(inputs.size());

  if (batch_size == -1) {
    // set default batch size
    batch_size = builder->n_parallel * 2;
  }

  // Remember the caller's verbosity so it can be restored after a debug switch.
  int old_verbosity = verbose;

  StdCout(verbose) << "Get " << inputs.size() << " programs to measure:" << std::endl;

  for (size_t i = 0; i < inputs.size(); i += batch_size) {
    Array<MeasureInput> input_batch(inputs.begin() + i,
                                    inputs.begin() + std::min(i + batch_size, inputs.size()));
    Array<MeasureResult> result_batch;

    // build and run
    SilentMeasure(task, input_batch, &result_batch);

    // update current best state according to the new measure result
    for (size_t j = 0; j < input_batch.size(); ++j) {
      // Records are keyed per workload, so one Measure call can serve inputs
      // from multiple tasks.
      const String& workload_key = input_batch[j]->task->workload_key;
      double flops;

      if (result_batch[j]->error_no == 0) {
        flops = task->compute_dag->flop_ct / FloatArrayMean(result_batch[j]->costs);
        // A success resets the continuous-error counter.
        error_ct = 0;
        has_valid.insert(workload_key);
      } else {
        // Failed measurements count as 0 GFLOPS and extend the error streak.
        flops = 0.0;
        error_ct++;
      }

      // operator[] default-inserts 0.0 on first sight of a workload, so any
      // successful measurement beats the implicit initial record.
      if (flops > best_flops[workload_key]) {
        best_flops[workload_key] = flops;
        best_state[workload_key] = input_batch[j]->state;
        best_ct[workload_key] = ct;
      }

      ct++;
      StdCout(verbose, 2) << std::fixed << std::setprecision(2) << Chars('=', 50) << "\n"
                          << "No: " << ct << "\tGFLOPS: " << flops / 1e9 << " / "
                          << best_flops[workload_key] / 1e9 << "\tresults: " << result_batch[j]
                          << "\n"
                          << Chars('=', 50) << "\n"
                          << input_batch[j]->state << "\n";
    }

    // Call callback functions
    if (callbacks) {
      for (const auto& callback : callbacks.value()) {
        callback->Callback(policy, input_batch, result_batch);
      }
    }

    // Store result batch
    for (auto& res : result_batch) {
      results.push_back(res);
    }

    // Toggle debug verbosity based on the current error streak; restored as
    // soon as a batch brings the streak back under the budget.
    if (error_ct > max_continuous_error) {
      LOG(WARNING) << "Too many errors happened during tuning. Switching to debug mode."
                   << std::endl;
      verbose = 2;
    } else {
      verbose = old_verbosity;
    }
  }

  PrintTimeElapsed(t_begin, "measurement", verbose);

  return results;
}
313 | |
314 | void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, const Array<MeasureInput>& inputs, |
315 | Array<MeasureResult>* results) { |
316 | results->clear(); |
317 | results->reserve(inputs.size()); |
318 | |
319 | // Call builder and runner |
320 | Array<BuildResult> build_res_batch = builder->Build(inputs, verbose); |
321 | Array<MeasureResult> result_batch = runner->Run(inputs, build_res_batch, verbose); |
322 | |
323 | // Store result batch |
324 | for (auto& res : result_batch) { |
325 | results->push_back(res); |
326 | } |
327 | } |
328 | |
329 | /********** Printing functions **********/ |
330 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
331 | .set_dispatch<MeasureInputNode>([](const ObjectRef& ref, ReprPrinter* p) { |
332 | p->stream << "MeasureInput()" ; |
333 | }); |
334 | |
335 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
336 | .set_dispatch<MeasureResultNode>([](const ObjectRef& ref, ReprPrinter* p) { |
337 | auto* node = static_cast<const MeasureResultNode*>(ref.get()); |
338 | if (node->error_no == static_cast<int>(MeasureErrorNO::kNoError)) { |
339 | p->stream << "MeasureResult(cost:[" ; |
340 | auto old_config = p->stream.precision(4); |
341 | for (size_t i = 0; i < node->costs.size(); ++i) { |
342 | auto pf = node->costs[i].as<FloatImmNode>(); |
343 | ICHECK(pf != nullptr); |
344 | p->stream << pf->value; |
345 | if (i != node->costs.size() - 1) { |
346 | p->stream << "," ; |
347 | } |
348 | } |
349 | p->stream.precision(old_config); |
350 | p->stream << "], " ; |
351 | p->stream << "error_no:" << 0 << ", " |
352 | << "all_cost:" << node->all_cost << ", " |
353 | << "Tstamp:" << node->timestamp << ")" ; |
354 | } else { |
355 | p->stream << "MeasureResult(" |
356 | << "error_type:" << ErrorNoToStr[node->error_no] << ", " |
357 | << "error_msg:" << node->error_msg << ", " |
358 | << "all_cost:" << node->all_cost << ", " |
359 | << "Tstamp:" << node->timestamp << ")" ; |
360 | } |
361 | }); |
362 | |
363 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
364 | .set_dispatch<BuildResultNode>([](const ObjectRef& ref, ReprPrinter* p) { |
365 | auto* node = static_cast<const BuildResultNode*>(ref.get()); |
366 | p->stream << "BuildResult(" << node->filename << ", " << node->error_no << ", " |
367 | << node->time_cost << ")" ; |
368 | }); |
369 | |
370 | /********** Measure interface API for ffi **********/ |
// FFI constructor for MeasureInput.
TVM_REGISTER_GLOBAL("auto_scheduler.MeasureInput").set_body_typed([](SearchTask task, State state) {
  return MeasureInput(task, state);
});

// FFI constructor for BuildResult.
TVM_REGISTER_GLOBAL("auto_scheduler.BuildResult")
    .set_body_typed([](String filename, Array<te::Tensor> args, int error_no, String error_msg,
                       double time_cost) {
      return BuildResult(filename, args, error_no, error_msg, time_cost);
    });

// FFI constructor for MeasureResult.
TVM_REGISTER_GLOBAL("auto_scheduler.MeasureResult")
    .set_body_typed([](Array<PrimExpr> costs, int error_no, String error_msg, double all_cost,
                       double timestamp) {
      return MeasureResult(costs, error_no, error_msg, all_cost, timestamp);
    });

// FFI constructor wrapping a Python callback as a MeasureCallback.
TVM_REGISTER_GLOBAL("auto_scheduler.PythonBasedMeasureCallback")
    .set_body_typed([](PackedFunc callback_func) {
      return PythonBasedMeasureCallback(callback_func);
    });

// FFI constructor for ProgramMeasurer.
TVM_REGISTER_GLOBAL("auto_scheduler.ProgramMeasurer")
    .set_body_typed([](ProgramBuilder builder, ProgramRunner runner,
                       Array<MeasureCallback> callbacks, int verbose, int max_continuous_error) {
      return ProgramMeasurer(builder, runner, callbacks, verbose, max_continuous_error);
    });
397 | |
// FFI entry point invoking a builder's virtual Build on a batch of inputs.
TVM_REGISTER_GLOBAL("auto_scheduler.ProgramBuilderBuild")
    .set_body_typed([](const ProgramBuilder& builder, const Array<MeasureInput>& inputs,
                       int verbose) { return builder->Build(inputs, verbose); });

// FFI entry point invoking a runner's virtual Run on built programs.
TVM_REGISTER_GLOBAL("auto_scheduler.ProgramRunnerRun")
    .set_body_typed([](const ProgramRunner& runner, const Array<MeasureInput>& inputs,
                       const Array<BuildResult>& build_results,
                       int verbose) { return runner->Run(inputs, build_results, verbose); });

// FFI constructor for LocalBuilder.
TVM_REGISTER_GLOBAL("auto_scheduler.LocalBuilder")
    .set_body_typed([](int timeout, int n_parallel, const String& build_func) {
      return LocalBuilder(timeout, n_parallel, build_func);
    });

// FFI constructor for LocalRunner.
TVM_REGISTER_GLOBAL("auto_scheduler.LocalRunner")
    .set_body_typed([](int timeout, int number, int repeat, int min_repeat_ms,
                       double cooldown_interval, bool enable_cpu_cache_flush, int device) {
      return LocalRunner(timeout, number, repeat, min_repeat_ms, cooldown_interval,
                         enable_cpu_cache_flush, device);
    });

// FFI constructor for RPCRunner.
TVM_REGISTER_GLOBAL("auto_scheduler.RPCRunner")
    .set_body_typed([](const String& key, const String& host, int port, int priority,
                       int n_parallel, int timeout, int number, int repeat, int min_repeat_ms,
                       double cooldown_interval, bool enable_cpu_cache_flush, int device) {
      return RPCRunner(key, host, port, priority, n_parallel, timeout, number, repeat,
                       min_repeat_ms, cooldown_interval, enable_cpu_cache_flush, device);
    });
426 | |
427 | } // namespace auto_scheduler |
428 | } // namespace tvm |
429 | |