1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/measure.h
22 * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs.
23 * These functions are responsible for building the tvm module, uploading it to remote devices,
24 * recording the running time costs, and checking the correctness of the output.
25 *
26 * The measurement is separated into two steps: build and run.
27 * A builder builds the executable binary files and a runner runs the binary files to get the
28 * measurement results. The flow of data structures is
29 *
30 * `ProgramBuilder` `ProgramRunner`
31 * `MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult`
32 *
 * The core functions are implemented in Python to utilize Python's multiprocessing
 * and error handling (see also `python/tvm/auto_scheduler/measure.py`).
 * This C++ file is just a wrapper for the Python functions.
36 */
37
38#ifndef TVM_AUTO_SCHEDULER_MEASURE_H_
39#define TVM_AUTO_SCHEDULER_MEASURE_H_
40
41#include <tvm/auto_scheduler/loop_state.h>
42#include <tvm/auto_scheduler/search_task.h>
43
44#include <string>
45#include <unordered_map>
46#include <unordered_set>
47#include <utility>
48
49namespace tvm {
50namespace auto_scheduler {
51
// Forward declarations.
// - SearchPolicy is only used by const reference in this header
//   (MeasureCallbackNode::Callback, ProgramMeasurerNode::Measure).
// - MeasureInput / MeasureResult are the return types of MeasureInputNode::copy() and
//   MeasureResultNode::copy(), which are declared before these reference classes are defined.
class SearchPolicy;
class MeasureInput;
class MeasureResult;
55
/*!
 * \brief The error code of one measurement.
 * \note BuildResultNode::error_no and MeasureResultNode::error_no store these codes as a
 * plain int (0, i.e. kNoError, means no error).
 * NOTE(review): the numeric values are presumably mirrored by the Python implementation
 * (python/tvm/auto_scheduler/measure.py); confirm before renumbering or inserting entries.
 */
enum class MeasureErrorNO : int {
  /*! \brief No error. */
  kNoError = 0,
  /*! \brief Errors that happen when applying the transform steps to the initial state. */
  kInstantiationError = 1,
  /*! \brief Errors that happen when compiling code on the host (when building the module). */
  kCompileHostError = 2,
  /*! \brief Errors that happen when compiling code on the device (when loading the module). */
  kCompileDeviceError = 3,
  /*! \brief Errors that happen when running the program on the device. */
  kRuntimeDeviceError = 4,
  /*! \brief The answer is wrong when compared to a reference output. */
  kWrongAnswerError = 5,
  /*! \brief Timeout during compilation. */
  kBuildTimeoutError = 6,
  /*! \brief Timeout during run. */
  kRunTimeoutError = 7,
  /*! \brief Unknown error. */
  kUnknownError = 8,
};
77
78// Inputs and results of one measurement
79
/*! \brief Store the input of a measurement: a search task and a program state for it. */
class MeasureInputNode : public Object {
 public:
  /*! \brief The search task. */
  SearchTask task;
  /*! \brief The program state to be measured. */
  State state;

  /*! \brief Expose the fields (`task`, `state`) to TVM's attribute visitor. */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("task", &task);
    v->Visit("state", &state);
  }

  /*!
   * \brief Do a shallow copy.
   * \return A MeasureInput that is a shallow copy of this node.
   */
  MeasureInput copy() const;

  static constexpr const char* _type_key = "auto_scheduler.MeasureInput";
  TVM_DECLARE_FINAL_OBJECT_INFO(MeasureInputNode, Object);
};
99
100/*!
101 * \brief Managed reference to MeasureInputNode.
102 * \sa MeasureInputNode
103 */
104class MeasureInput : public ObjectRef {
105 public:
106 /*!
107 * \brief The constructor.
108 * \param task The SearchTask of this measure.
109 * \param state The State to be measured.
110 */
111 MeasureInput(SearchTask task, State state);
112
113 TVM_DEFINE_OBJECT_REF_METHODS(MeasureInput, ObjectRef, MeasureInputNode);
114};
115
/*! \brief Store the result of a build. */
class BuildResultNode : public Object {
 public:
  /*! \brief The filename of the built binary file. */
  String filename;
  /*! \brief The tensor arguments associated with the built program. */
  Array<te::Tensor> args;
  /*! \brief The error code, stored as an int (0 means no error, see MeasureErrorNO). */
  int error_no;
  /*! \brief The error message if there is any error. */
  String error_msg;
  /*!
   * \brief The time cost of the build.
   * NOTE(review): the unit is not stated here; presumably seconds. Confirm in measure.py.
   */
  double time_cost;

  /*! \brief Expose all fields to TVM's attribute visitor, in declaration order. */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("filename", &filename);
    v->Visit("args", &args);
    v->Visit("error_no", &error_no);
    v->Visit("error_msg", &error_msg);
    v->Visit("time_cost", &time_cost);
  }

  static constexpr const char* _type_key = "auto_scheduler.BuildResult";
  TVM_DECLARE_FINAL_OBJECT_INFO(BuildResultNode, Object);
};
141
142/*!
143 * \brief Managed reference to BuildResultNode.
144 * \sa BuildResultNode
145 */
146class BuildResult : public ObjectRef {
147 public:
148 /*!
149 * \brief The constructor.
150 * \param filename The filename of built binary file.
151 * \param args The arguments.
152 * \param error_no The error code.
153 * \param error_msg The error message if there is any error.
154 * \param time_cost The time cost of build.
155 */
156 BuildResult(String filename, Array<te::Tensor> args, int error_no, String error_msg,
157 double time_cost);
158 TVM_DEFINE_OBJECT_REF_METHODS(BuildResult, ObjectRef, BuildResultNode);
159};
160
/*! \brief Store the results of a measurement. */
class MeasureResultNode : public Object {
 public:
  /*! \brief The time costs of execution. */
  Array<PrimExpr> costs;
  /*! \brief The error code, stored as an int (0 means no error, see MeasureErrorNO). */
  int error_no;
  /*! \brief The error message if there is any error. */
  String error_msg;
  /*!
   * \brief The total time cost of build and run.
   * NOTE(review): the unit is not stated here; presumably seconds. Confirm in measure.py.
   */
  double all_cost;
  /*! \brief The timestamp of this measurement. */
  double timestamp;

  /*! \brief Expose all fields to TVM's attribute visitor, in declaration order. */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("costs", &costs);
    v->Visit("error_no", &error_no);
    v->Visit("error_msg", &error_msg);
    v->Visit("all_cost", &all_cost);
    v->Visit("timestamp", &timestamp);
  }

  /*!
   * \brief Do a shallow copy.
   * \return A MeasureResult that is a shallow copy of this node.
   */
  MeasureResult copy() const;

  static constexpr const char* _type_key = "auto_scheduler.MeasureResult";
  TVM_DECLARE_FINAL_OBJECT_INFO(MeasureResultNode, Object);
};
189
190/*!
191 * \brief Managed reference to MeasureResultNode.
192 * \sa MeasureResultNode
193 */
194class MeasureResult : public ObjectRef {
195 public:
196 /*!
197 * \brief The constructor.
198 * \param costs The time costs of execution.
199 * \param error_no The error code.
200 * \param error_msg The error message if there is any error.
201 * \param all_cost The time cost of build and run.
202 * \param timestamp The time stamps of this measurement.
203 */
204 MeasureResult(Array<PrimExpr> costs, int error_no, String error_msg, double all_cost,
205 double timestamp);
206
207 TVM_DEFINE_OBJECT_REF_METHODS(MeasureResult, ObjectRef, MeasureResultNode);
208};
209
/*! \brief Base class of measurement callbacks. Subclasses must implement Callback(). */
class MeasureCallbackNode : public Object {
 public:
  /*!
   * \brief Callback function that will be called on measurement input/result pairs
   * after each measurement batch.
   * \param policy The current search policy.
   * \param inputs An Array of MeasureInput.
   * \param results An Array of MeasureResult. NOTE(review): presumably results[i] is the
   *        result of inputs[i]; confirm with the caller (ProgramMeasurerNode::Measure).
   */
  virtual void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
                        const Array<MeasureResult>& results) = 0;
  static constexpr const char* _type_key = "auto_scheduler.MeasureCallback";
  TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object);
};
225
226/*!
227 * \brief Managed reference to MeasureCallbackNode.
228 * \sa MeasureCallbackNode
229 */
230class MeasureCallback : public ObjectRef {
231 public:
232 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode);
233};
234
/*!
 * \brief A wrapper for a measure callback defined in Python.
 * This class calls a function defined in Python through a PackedFunc.
 */
class PythonBasedMeasureCallbackNode : public MeasureCallbackNode {
 public:
  /*! \brief The PackedFunc wrapping the callback function defined in Python. */
  PackedFunc callback_func;

  /*!
   * \brief Implements MeasureCallbackNode::Callback.
   * NOTE(review): defined out of line; presumably forwards its arguments to callback_func.
   * Confirm in measure.cc.
   */
  void Callback(const SearchPolicy& policy, const Array<MeasureInput>& inputs,
                const Array<MeasureResult>& results) final;
  static constexpr const char* _type_key = "auto_scheduler.PythonBasedMeasureCallback";
  TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedMeasureCallbackNode, MeasureCallbackNode);
};
247
248/*!
249 * \brief Managed reference to PythonBasedMeasureCallbackNode.
250 * \sa PythonBasedMeasureCallbackNode
251 */
252class PythonBasedMeasureCallback : public MeasureCallback {
253 public:
254 /*!
255 * \brief The constructor.
256 * \param callback_func The pointer to the callback function defined in python
257 */
258 explicit PythonBasedMeasureCallback(PackedFunc callback_func);
259
260 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedMeasureCallback, MeasureCallback,
261 PythonBasedMeasureCallbackNode);
262};
263
// The base classes of ProgramBuilders and ProgramRunners.
265
/*! \brief ProgramBuilder that builds the programs. Subclasses must implement Build(). */
class ProgramBuilderNode : public Object {
 public:
  /*! \brief The number of build processes to run in parallel. */
  int n_parallel;
  /*! \brief Timeout of a build (the LocalBuilder constructor documents it in seconds). */
  int timeout;

  /*!
   * \brief Build programs and return results.
   * \param inputs An Array of MeasureInput.
   * \param verbose Verbosity level. 0 for silent, 1 to output information during program
   * building.
   * \return An Array of BuildResult.
   */
  virtual Array<BuildResult> Build(const Array<MeasureInput>& inputs, int verbose) = 0;

  static constexpr const char* _type_key = "auto_scheduler.ProgramBuilder";
  TVM_DECLARE_BASE_OBJECT_INFO(ProgramBuilderNode, Object);
};
286
287/*!
288 * \brief Managed reference to ProgramBuilderNode.
289 * \sa ProgramBuilderNode
290 */
291class ProgramBuilder : public ObjectRef {
292 public:
293 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramBuilder, ObjectRef, ProgramBuilderNode);
294};
295
/*!
 * \brief ProgramRunner that runs the built programs and measures the time cost.
 * Subclasses must implement Run().
 */
class ProgramRunnerNode : public Object {
 public:
  /*! \brief Timeout of a run (the LocalRunner constructor documents it in seconds). */
  int timeout;
  /*! \brief The number of times to run the generated code for taking average. */
  int number;
  /*! \brief The number of times to repeat the measurement. */
  int repeat;
  /*! \brief The minimum duration of one repeat in milliseconds. */
  int min_repeat_ms;
  /*!
   * \brief The cool down interval between two measurements.
   * NOTE(review): the unit is not stated here; presumably seconds. Confirm in measure.py.
   */
  double cooldown_interval;
  /*! \brief Whether to flush cache on CPU between repeated measurements. */
  bool enable_cpu_cache_flush;
  /*! \brief Which device to run on if multiple are available. */
  int device;

  /*!
   * \brief Run measurement and return results.
   * \param inputs An Array of MeasureInput.
   * \param build_results An Array of BuildResult, presumably produced by a ProgramBuilder
   *        for the same `inputs`. NOTE(review): confirm the element-wise correspondence.
   * \param verbose Verbosity level. 0 for silent, 1 to output information during program
   * running.
   * \return An Array of MeasureResult.
   */
  virtual Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
                                   const Array<BuildResult>& build_results, int verbose) = 0;

  static constexpr const char* _type_key = "auto_scheduler.ProgramRunner";
  TVM_DECLARE_BASE_OBJECT_INFO(ProgramRunnerNode, Object);
};
328
329/*!
330 * \brief Managed reference to ProgramRunnerNode.
331 * \sa ProgramRunnerNode
332 */
333class ProgramRunner : public ObjectRef {
334 public:
335 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramRunner, ObjectRef, ProgramRunnerNode);
336};
337
338// Implementation of various builders and runners
339
/*! \brief LocalBuilder uses local CPU cores to build programs in parallel. */
class LocalBuilderNode : public ProgramBuilderNode {
 public:
  /*! \brief The name of the registered build function. */
  String build_func;

  /*! \brief Implements ProgramBuilderNode::Build (defined out of line). */
  Array<BuildResult> Build(const Array<MeasureInput>& inputs, int verbose) final;

  static constexpr const char* _type_key = "auto_scheduler.LocalBuilder";
  TVM_DECLARE_FINAL_OBJECT_INFO(LocalBuilderNode, ProgramBuilderNode);
};
351
352/*!
353 * \brief Managed reference to LocalBuilderNode.
354 * \sa LocalBuilderNode
355 */
356class LocalBuilder : public ProgramBuilder {
357 public:
358 /*!
359 * \brief The constructor.
360 * \param timeout The timeout limit (in second) for each build thread.
361 * This will be used in a wrapper of the multiprocessing.Process.join().
362 * \param n_parallel The number of threads used to build in parallel.
363 * \param build_func The name of the registered build function.
364 */
365 LocalBuilder(int timeout, int n_parallel, const String& build_func);
366
367 TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, ProgramBuilder, LocalBuilderNode);
368};
369
/*! \brief LocalRunner that uses the local CPU/GPU to measure the time cost of programs. */
class LocalRunnerNode : public ProgramRunnerNode {
 public:
  /*! \brief Implements ProgramRunnerNode::Run (defined out of line). */
  Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
                           const Array<BuildResult>& build_results, int verbose) final;

  static constexpr const char* _type_key = "auto_scheduler.LocalRunner";
  TVM_DECLARE_FINAL_OBJECT_INFO(LocalRunnerNode, ProgramRunnerNode);
};
379
380/*!
381 * \brief Managed reference to LocalRunnerNode.
382 * \sa LocalRunnerNode
383 */
384class LocalRunner : public ProgramRunner {
385 public:
386 /*!
387 * \brief The constructor. See the corresponding class in python/tvm/auto_scheduler/measure.py
388 * for more detailed parameter explanation.
389 * \param timeout The timeout limit (in second) for each run.
390 * This is used in a wrapper of the multiprocessing.Process.join().
391 * \param number The number of times to run the generated code for taking average.
392 * \param repeat The number of times to repeat the measurement.
393 * \param min_repeat_ms The minimum duration of one repeat in milliseconds.
394 * \param cooldown_interval The cool down interval between two measurements.
395 * \param enable_cpu_cache_flush Whether to flush cache on CPU between repeated measurements.
396 * \param device Which device to run on if multiple are available.
397 */
398 LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval,
399 bool enable_cpu_cache_flush, int device);
400
401 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LocalRunner, ProgramRunner, LocalRunnerNode);
402};
403
404/*!
405 * \brief RPCRunner that uses RPC call to measures the time cost of programs on remote devices.
406 * Or sometime we may need to use RPC even in local running to insulate the thread environment.
407 * (e.g. running CUDA programs)
408 */
409class RPCRunnerNode : public ProgramRunnerNode {
410 public:
411 /*! \brief The key of the device registered in the RPC tracker. */
412 String key;
413 /*! \brief The host address of the RPC Tracker. */
414 String host;
415 /*! \brief The port of the RPC Tracker. */
416 int port;
417 /*! \brief The priority of this run request, larger is more prior. */
418 int priority;
419 /*! \brief The number of tasks run in parallel. */
420 int n_parallel;
421
422 Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
423 const Array<BuildResult>& build_results, int verbose) final;
424
425 static constexpr const char* _type_key = "auto_scheduler.RPCRunner";
426 TVM_DECLARE_FINAL_OBJECT_INFO(RPCRunnerNode, ProgramRunnerNode);
427};
428
429/*!
430 * \brief Managed reference to RPCRunnerNode.
431 * \sa RPCRunnerNode
432 */
433class RPCRunner : public ProgramRunner {
434 public:
435 /*!
436 * \brief The constructor. See the corresponding class in python/tvm/auto_scheduler/measure.py
437 * for more detailed parameter explanation.
438 * \param key The key of the device registered in the RPC tracker.
439 * \param host The host address of the RPC Tracker.
440 * \param port The port of RPC Tracker.
441 * \param priority The priority of this run request, larger is more prior.
442 * \param n_parallel The number of tasks run in parallel.
443 * \param timeout Timeout of a run.
444 * \param number The number of times to run the generated code for taking average.
445 * \param repeat The number of times to repeat the measurement.
446 * \param min_repeat_ms The minimum duration of one repeat in milliseconds.
447 * \param cooldown_interval The cool down interval between two measurements.
448 * \param enable_cpu_cache_flush Whether to flush cache on CPU between repeated measurements.
449 * \param device Which device to run on if multiple are available.
450 */
451 RPCRunner(const String& key, const String& host, int port, int priority, int n_parallel,
452 int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval,
453 bool enable_cpu_cache_flush, int device);
454
455 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RPCRunner, ProgramRunner, RPCRunnerNode);
456};
457
458/*!
459 * \brief Measurer that measures the time costs of tvm programs
460 * This class combines ProgramBuilder and ProgramRunner, and provides a simpler API */
461class ProgramMeasurerNode : public Object {
462 public:
463 /*! \brief Measured programs counter. */
464 int ct;
465 /*! \brief Continuous error counter. */
466 int error_ct;
467 /*! \brief Workload key to best flops map. */
468 std::unordered_map<std::string, double> best_flops;
469 /*! \brief Workload key to best state map. */
470 std::unordered_map<std::string, State> best_state;
471 /*! \brief Workload key to best state's count index map. */
472 std::unordered_map<std::string, int> best_ct;
473 /*! \brief The set of workloads that have at least one valid schedule */
474 std::unordered_set<std::string> has_valid;
475 /*! \brief The ProgramBuilder to build each program. */
476 ProgramBuilder builder;
477 /*! \brief The ProgramRunner to measure each program. */
478 ProgramRunner runner;
479 /*! \brief MeasureCallback to be called after each measure batch. */
480 Optional<Array<MeasureCallback>> callbacks;
481 /*! \brief Verbosity level. 0 for silent, 1 to output information during program measuring. */
482 int verbose;
483 /*! \brief The number of allowed maximum continuous error before forcely stopping the tuning */
484 int max_continuous_error;
485
486 /*! \brief Reset book keeping variables */
487 void Reset();
488
489 /*!
490 * \brief Do measurement.
491 * \param task The current SearchTask.
492 * \param policy The current SearchPolicy.
493 * \param inputs The inputs of measurement.
494 * \param batch_size Number of programs to be measured in one batch.
495 * \return results The results of measurement.
496 */
497 Array<MeasureResult> Measure(const SearchTask& task, const SearchPolicy& policy,
498 const Array<MeasureInput>& inputs, int batch_size = -1);
499 /*!
500 * \brief Do measurement silently.
501 * This API will not print the measure results to screen.
502 * \param task The current SearchTask.
503 * \param inputs The MeasureInputs.
504 * \param results A pointer to a MeasureResult Array, this is used as output.
505 */
506 void SilentMeasure(const SearchTask& task, const Array<MeasureInput>& inputs,
507 Array<MeasureResult>* results);
508
509 /*! \brief The default max continuous error setting. */
510 static const int DEFAULT_MAX_CONTINUOUS_ERROR = 150;
511
512 static constexpr const char* _type_key = "auto_scheduler.ProgramMeasurer";
513 TVM_DECLARE_FINAL_OBJECT_INFO(ProgramMeasurerNode, Object);
514};
515
516/*!
517 * \brief Managed reference to ProgramMeasurerNode.
518 * \sa ProgramMeasurerNode
519 */
520class ProgramMeasurer : public ObjectRef {
521 public:
522 /*!
523 * \brief The constructor.
524 * \param builder The ProgramBuilder to build programs.
525 * \param runner The ProgramRunner to measure programs.
526 * \param callbacks MeasureCallback to be called after each measurement batch.
527 * \param verbose Verbosity level. 0 for silent, 1 to output information during program
528 * measuring.
529 * \param max_continuous_error The number of allowed maximum continuous error before
530 * forcely stopping the tuning.
531 */
532 ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner,
533 Optional<Array<MeasureCallback>> callbacks, int verbose,
534 int max_continuous_error = -1);
535
536 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramMeasurer, ObjectRef, ProgramMeasurerNode);
537};
538
539} // namespace auto_scheduler
540} // namespace tvm
541
542#endif // TVM_AUTO_SCHEDULER_MEASURE_H_
543