1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/measure.cc
22 * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs.
23 */
24
25#include <tvm/auto_scheduler/measure.h>
26#include <tvm/runtime/registry.h>
27
28#include <algorithm>
29
30#include "search_policy/empty_policy.h"
31#include "search_policy/sketch_policy.h"
32#include "utils.h"
33
34namespace tvm {
35namespace auto_scheduler {
36
37TVM_REGISTER_NODE_TYPE(MeasureInputNode);
38TVM_REGISTER_NODE_TYPE(BuildResultNode);
39TVM_REGISTER_NODE_TYPE(MeasureResultNode);
40TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode);
41TVM_REGISTER_OBJECT_TYPE(PythonBasedMeasureCallbackNode);
42TVM_REGISTER_OBJECT_TYPE(ProgramRunnerNode);
43TVM_REGISTER_OBJECT_TYPE(ProgramBuilderNode);
44TVM_REGISTER_OBJECT_TYPE(ProgramMeasurerNode);
45TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode);
46TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode);
47TVM_REGISTER_OBJECT_TYPE(RPCRunnerNode);
48
// Human-readable names for measurement error codes.
// Indexed directly by MeasureResultNode::error_no, so the order and count here
// must stay in sync with the MeasureErrorNO enum (kNoError ... kUnknownError).
static const char* ErrorNoToStr[] = {
    "NoError",
    "InstantiationError",
    "CompileHostError",
    "CompileDeviceError",
    "RuntimeDeviceError",
    "WrongAnswerError",
    "BuildTimeoutError",
    "RunTimeoutError",
    "UnknownError",
};
60
61/********** Measure input and result **********/
62MeasureInput::MeasureInput(SearchTask task, State state) {
63 auto node = make_object<MeasureInputNode>();
64 node->task = std::move(task);
65 node->state = std::move(state);
66 data_ = std::move(node);
67}
68
69MeasureInput MeasureInputNode::copy() const {
70 auto node = make_object<MeasureInputNode>();
71 node->task = task;
72 node->state = state;
73 return MeasureInput(node);
74}
75
76BuildResult::BuildResult(String filename, Array<te::Tensor> args, int error_no, String error_msg,
77 double time_cost) {
78 auto node = make_object<BuildResultNode>();
79 node->filename = std::move(filename);
80 node->args = std::move(args);
81 node->error_no = error_no;
82 node->error_msg = std::move(error_msg);
83 node->time_cost = time_cost;
84 data_ = std::move(node);
85}
86
87MeasureResult::MeasureResult(Array<PrimExpr> costs, int error_no, String error_msg, double all_cost,
88 double timestamp) {
89 auto node = make_object<MeasureResultNode>();
90 node->costs = std::move(costs);
91 node->error_no = error_no;
92 node->error_msg = std::move(error_msg);
93 node->all_cost = all_cost;
94 node->timestamp = timestamp;
95 data_ = std::move(node);
96}
97
98MeasureResult MeasureResultNode::copy() const {
99 auto node = make_object<MeasureResultNode>();
100 node->costs = costs;
101 node->error_no = error_no;
102 node->error_msg = error_msg;
103 node->all_cost = all_cost;
104 node->timestamp = timestamp;
105 return MeasureResult(node);
106}
107
108/********** LocalBuilder **********/
109LocalBuilder::LocalBuilder(int timeout, int n_parallel, const String& build_func) {
110 auto node = make_object<LocalBuilderNode>();
111 node->timeout = timeout;
112 node->n_parallel = n_parallel;
113 node->build_func = build_func;
114 data_ = std::move(node);
115}
116
117Array<BuildResult> LocalBuilderNode::Build(const Array<MeasureInput>& inputs, int verbose) {
118 if (const auto* f = runtime::Registry::Get("auto_scheduler.local_builder.build")) {
119 Array<BuildResult> results = (*f)(inputs, timeout, n_parallel, build_func, verbose);
120 return results;
121 }
122 LOG(FATAL) << "auto_scheduler.local_builder.build is not registered. "
123 << "This is a function registered in Python, "
124 << "make sure the TVM Python runtime has been loaded successfully.";
125 throw;
126}
127
128/********** LocalRunner **********/
129LocalRunner::LocalRunner(int timeout, int number, int repeat, int min_repeat_ms,
130 double cooldown_interval, bool enable_cpu_cache_flush, int device) {
131 ObjectPtr<LocalRunnerNode> node = make_object<LocalRunnerNode>();
132 node->timeout = timeout;
133 node->number = number;
134 node->repeat = repeat;
135 node->min_repeat_ms = min_repeat_ms;
136 node->cooldown_interval = cooldown_interval;
137 node->enable_cpu_cache_flush = enable_cpu_cache_flush;
138 node->device = device;
139 data_ = std::move(node);
140}
141
142Array<MeasureResult> LocalRunnerNode::Run(const Array<MeasureInput>& inputs,
143 const Array<BuildResult>& build_results, int verbose) {
144 if (const auto* f = runtime::Registry::Get("auto_scheduler.local_runner.run")) {
145 Array<MeasureResult> results =
146 (*f)(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval,
147 enable_cpu_cache_flush, verbose, device);
148 return results;
149 }
150 LOG(FATAL) << "auto_scheduler.local_runner.run is not registered. "
151 << "This is a function registered in Python, "
152 << "make sure the TVM Python runtime has been loaded successfully.";
153 throw;
154}
155
156/********** RPCRunner **********/
157RPCRunner::RPCRunner(const String& key, const String& host, int port, int priority, int n_parallel,
158 int timeout, int number, int repeat, int min_repeat_ms,
159 double cooldown_interval, bool enable_cpu_cache_flush, int device) {
160 auto node = make_object<RPCRunnerNode>();
161 node->key = key;
162 node->host = host;
163 node->port = port;
164 node->priority = priority;
165 node->timeout = timeout;
166 node->n_parallel = n_parallel;
167 node->number = number;
168 node->repeat = repeat;
169 node->min_repeat_ms = min_repeat_ms;
170 node->cooldown_interval = cooldown_interval;
171 node->enable_cpu_cache_flush = enable_cpu_cache_flush;
172 node->device = device;
173 data_ = std::move(node);
174}
175
176Array<MeasureResult> RPCRunnerNode::Run(const Array<MeasureInput>& inputs,
177 const Array<BuildResult>& build_results, int verbose) {
178 if (const auto* f = runtime::Registry::Get("auto_scheduler.rpc_runner.run")) {
179 Array<MeasureResult> results =
180 (*f)(inputs, build_results, key, host, port, priority, n_parallel, timeout, number, repeat,
181 min_repeat_ms, cooldown_interval, enable_cpu_cache_flush, verbose, device);
182 return results;
183 } else {
184 LOG(FATAL) << "auto_scheduler.rpc_runner.run is not registered. "
185 << "This is a function registered in Python, "
186 << "make sure the TVM Python runtime has been loaded successfully.";
187 }
188 return Array<MeasureResult>();
189}
190
191/********** MeasureCallback **********/
192PythonBasedMeasureCallback::PythonBasedMeasureCallback(PackedFunc callback_func) {
193 auto node = make_object<PythonBasedMeasureCallbackNode>();
194 node->callback_func = std::move(callback_func);
195 data_ = std::move(node);
196}
197
198void PythonBasedMeasureCallbackNode::Callback(const SearchPolicy& policy,
199 const Array<MeasureInput>& inputs,
200 const Array<MeasureResult>& results) {
201 if (auto* sketch_policy = static_cast<SketchPolicyNode*>(policy.operator->())) {
202 callback_func(GetRef<SketchPolicy>(sketch_policy), inputs, results);
203 } else if (auto* empty_policy = static_cast<EmptyPolicyNode*>(policy.operator->())) {
204 callback_func(GetRef<EmptyPolicy>(empty_policy), inputs, results);
205 } else {
206 LOG(FATAL) << "Unrecognized search policy type. Expect SketchPolicy or EmptyPolicy";
207 }
208}
209
210/********** ProgramMeasurer **********/
211ProgramMeasurer::ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner,
212 Optional<Array<MeasureCallback>> callbacks, int verbose,
213 int max_continuous_error) {
214 auto node = make_object<ProgramMeasurerNode>();
215 node->builder = std::move(builder);
216 node->runner = std::move(runner);
217 node->callbacks = std::move(callbacks);
218 node->verbose = verbose;
219 node->max_continuous_error = max_continuous_error < 0
220 ? ProgramMeasurerNode::DEFAULT_MAX_CONTINUOUS_ERROR
221 : max_continuous_error;
222 data_ = std::move(node);
223}
224
225void ProgramMeasurerNode::Reset() {
226 ct = error_ct = 0;
227 best_flops.clear();
228 best_ct.clear();
229 best_state.clear();
230 has_valid.clear();
231}
232
/*!
 * \brief Measure a set of candidate programs in batches.
 *
 * For each batch: build and run via SilentMeasure, update the per-workload
 * best records (best_flops / best_state / best_ct), invoke the registered
 * callbacks, and append the results. After max_continuous_error consecutive
 * failed measurements, verbosity is raised to debug level (2) so subsequent
 * per-program reports are printed; it is restored once a batch ends without
 * exceeding the threshold.
 *
 * \param task The search task (provides compute_dag->flop_ct for GFLOPS).
 * \param policy The search policy, forwarded to callbacks.
 * \param inputs The candidate programs to measure.
 * \param batch_size Programs per batch; -1 selects builder->n_parallel * 2.
 * \return One MeasureResult per input, in input order.
 */
Array<MeasureResult> ProgramMeasurerNode::Measure(const SearchTask& task,
                                                  const SearchPolicy& policy,
                                                  const Array<MeasureInput>& inputs,
                                                  int batch_size) {
  auto t_begin = std::chrono::high_resolution_clock::now();

  Array<MeasureResult> results;
  results.reserve(inputs.size());

  if (batch_size == -1) {
    // set default batch size
    batch_size = builder->n_parallel * 2;
  }

  // Remember the caller's verbosity so it can be restored after a debug switch.
  int old_verbosity = verbose;

  StdCout(verbose) << "Get " << inputs.size() << " programs to measure:" << std::endl;

  for (size_t i = 0; i < inputs.size(); i += batch_size) {
    // Slice [i, i + batch_size) out of the inputs, clamped to the end.
    Array<MeasureInput> input_batch(inputs.begin() + i,
                                    inputs.begin() + std::min(i + batch_size, inputs.size()));
    Array<MeasureResult> result_batch;

    // build and run
    SilentMeasure(task, input_batch, &result_batch);

    // update current best state according to the new measure result
    for (size_t j = 0; j < input_batch.size(); ++j) {
      const String& workload_key = input_batch[j]->task->workload_key;
      double flops;

      if (result_batch[j]->error_no == 0) {
        // Achieved FLOPS = total FLOP count / mean measured cost.
        flops = task->compute_dag->flop_ct / FloatArrayMean(result_batch[j]->costs);
        // Any success resets the continuous-error counter.
        error_ct = 0;
        has_valid.insert(workload_key);
      } else {
        flops = 0.0;
        error_ct++;
      }

      // operator[] default-inserts 0.0 for an unseen workload, so the first
      // successful measurement always becomes the initial best.
      if (flops > best_flops[workload_key]) {
        best_flops[workload_key] = flops;
        best_state[workload_key] = input_batch[j]->state;
        best_ct[workload_key] = ct;
      }

      ct++;
      // Per-program report; only printed when verbose >= 2 (debug mode).
      StdCout(verbose, 2) << std::fixed << std::setprecision(2) << Chars('=', 50) << "\n"
                          << "No: " << ct << "\tGFLOPS: " << flops / 1e9 << " / "
                          << best_flops[workload_key] / 1e9 << "\tresults: " << result_batch[j]
                          << "\n"
                          << Chars('=', 50) << "\n"
                          << input_batch[j]->state << "\n";
    }

    // Call callback functions
    if (callbacks) {
      for (const auto& callback : callbacks.value()) {
        callback->Callback(policy, input_batch, result_batch);
      }
    }

    // Store result batch
    for (auto& res : result_batch) {
      results.push_back(res);
    }

    // Toggle debug verbosity based on the continuous-error count.
    if (error_ct > max_continuous_error) {
      LOG(WARNING) << "Too many errors happened during tuning. Switching to debug mode."
                   << std::endl;
      verbose = 2;
    } else {
      verbose = old_verbosity;
    }
  }

  PrintTimeElapsed(t_begin, "measurement", verbose);

  return results;
}
313
314void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, const Array<MeasureInput>& inputs,
315 Array<MeasureResult>* results) {
316 results->clear();
317 results->reserve(inputs.size());
318
319 // Call builder and runner
320 Array<BuildResult> build_res_batch = builder->Build(inputs, verbose);
321 Array<MeasureResult> result_batch = runner->Run(inputs, build_res_batch, verbose);
322
323 // Store result batch
324 for (auto& res : result_batch) {
325 results->push_back(res);
326 }
327}
328
329/********** Printing functions **********/
330TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
331 .set_dispatch<MeasureInputNode>([](const ObjectRef& ref, ReprPrinter* p) {
332 p->stream << "MeasureInput()";
333 });
334
335TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
336 .set_dispatch<MeasureResultNode>([](const ObjectRef& ref, ReprPrinter* p) {
337 auto* node = static_cast<const MeasureResultNode*>(ref.get());
338 if (node->error_no == static_cast<int>(MeasureErrorNO::kNoError)) {
339 p->stream << "MeasureResult(cost:[";
340 auto old_config = p->stream.precision(4);
341 for (size_t i = 0; i < node->costs.size(); ++i) {
342 auto pf = node->costs[i].as<FloatImmNode>();
343 ICHECK(pf != nullptr);
344 p->stream << pf->value;
345 if (i != node->costs.size() - 1) {
346 p->stream << ",";
347 }
348 }
349 p->stream.precision(old_config);
350 p->stream << "], ";
351 p->stream << "error_no:" << 0 << ", "
352 << "all_cost:" << node->all_cost << ", "
353 << "Tstamp:" << node->timestamp << ")";
354 } else {
355 p->stream << "MeasureResult("
356 << "error_type:" << ErrorNoToStr[node->error_no] << ", "
357 << "error_msg:" << node->error_msg << ", "
358 << "all_cost:" << node->all_cost << ", "
359 << "Tstamp:" << node->timestamp << ")";
360 }
361 });
362
363TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
364 .set_dispatch<BuildResultNode>([](const ObjectRef& ref, ReprPrinter* p) {
365 auto* node = static_cast<const BuildResultNode*>(ref.get());
366 p->stream << "BuildResult(" << node->filename << ", " << node->error_no << ", "
367 << node->time_cost << ")";
368 });
369
370/********** Measure interface API for ffi **********/
371TVM_REGISTER_GLOBAL("auto_scheduler.MeasureInput").set_body_typed([](SearchTask task, State state) {
372 return MeasureInput(task, state);
373});
374
375TVM_REGISTER_GLOBAL("auto_scheduler.BuildResult")
376 .set_body_typed([](String filename, Array<te::Tensor> args, int error_no, String error_msg,
377 double time_cost) {
378 return BuildResult(filename, args, error_no, error_msg, time_cost);
379 });
380
381TVM_REGISTER_GLOBAL("auto_scheduler.MeasureResult")
382 .set_body_typed([](Array<PrimExpr> costs, int error_no, String error_msg, double all_cost,
383 double timestamp) {
384 return MeasureResult(costs, error_no, error_msg, all_cost, timestamp);
385 });
386
387TVM_REGISTER_GLOBAL("auto_scheduler.PythonBasedMeasureCallback")
388 .set_body_typed([](PackedFunc callback_func) {
389 return PythonBasedMeasureCallback(callback_func);
390 });
391
392TVM_REGISTER_GLOBAL("auto_scheduler.ProgramMeasurer")
393 .set_body_typed([](ProgramBuilder builder, ProgramRunner runner,
394 Array<MeasureCallback> callbacks, int verbose, int max_continuous_error) {
395 return ProgramMeasurer(builder, runner, callbacks, verbose, max_continuous_error);
396 });
397
398TVM_REGISTER_GLOBAL("auto_scheduler.ProgramBuilderBuild")
399 .set_body_typed([](const ProgramBuilder& builder, const Array<MeasureInput>& inputs,
400 int verbose) { return builder->Build(inputs, verbose); });
401
402TVM_REGISTER_GLOBAL("auto_scheduler.ProgramRunnerRun")
403 .set_body_typed([](const ProgramRunner& runner, const Array<MeasureInput>& inputs,
404 const Array<BuildResult>& build_results,
405 int verbose) { return runner->Run(inputs, build_results, verbose); });
406
407TVM_REGISTER_GLOBAL("auto_scheduler.LocalBuilder")
408 .set_body_typed([](int timeout, int n_parallel, const String& build_func) {
409 return LocalBuilder(timeout, n_parallel, build_func);
410 });
411
412TVM_REGISTER_GLOBAL("auto_scheduler.LocalRunner")
413 .set_body_typed([](int timeout, int number, int repeat, int min_repeat_ms,
414 double cooldown_interval, bool enable_cpu_cache_flush, int device) {
415 return LocalRunner(timeout, number, repeat, min_repeat_ms, cooldown_interval,
416 enable_cpu_cache_flush, device);
417 });
418
419TVM_REGISTER_GLOBAL("auto_scheduler.RPCRunner")
420 .set_body_typed([](const String& key, const String& host, int port, int priority,
421 int n_parallel, int timeout, int number, int repeat, int min_repeat_ms,
422 double cooldown_interval, bool enable_cpu_cache_flush, int device) {
423 return RPCRunner(key, host, port, priority, n_parallel, timeout, number, repeat,
424 min_repeat_ms, cooldown_interval, enable_cpu_cache_flush, device);
425 });
426
427} // namespace auto_scheduler
428} // namespace tvm
429