1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file include/tvm/runtime/profiling.h
22 * \brief Runtime profiling including timers.
23 */
24#ifndef TVM_RUNTIME_PROFILING_H_
25#define TVM_RUNTIME_PROFILING_H_
26
27#include <tvm/runtime/c_runtime_api.h>
28#include <tvm/runtime/container/map.h>
29#include <tvm/runtime/device_api.h>
30#include <tvm/runtime/object.h>
31#include <tvm/runtime/packed_func.h>
32#include <tvm/runtime/registry.h>
33
34#include <stack>
35#include <string>
36#include <unordered_map>
37#include <utility>
38#include <vector>
39
40namespace tvm {
41
42namespace runtime {
43
44/*! \brief Base class for all timer implementations.
45 *
46 * New implementations of this interface should make sure that `Start` and `Stop`
47 * are as lightweight as possible. Expensive state synchronization should be
48 * done in `SyncAndGetElapsedNanos`.
49 */
50class TimerNode : public Object {
51 public:
52 /*! \brief Start the timer.
53 *
54 * Note: this function should only be called once per object.
55 */
56 virtual void Start() = 0;
57 /*! \brief Stop the timer.
58 *
59 * Note: this function should only be called once per object.
60 */
61 virtual void Stop() = 0;
62 /*! \brief Synchronize timer state and return elapsed time between `Start` and `Stop`.
63 * \return The time in nanoseconds between `Start` and `Stop`.
64 *
65 * This function is necessary because we want to avoid timing the overhead of
66 * doing timing. When using multiple timers, it is recommended to stop all of
67 * them before calling `SyncAndGetElapsedNanos` on any of them.
68 *
69 * Note: this function should only be called once per object. It may incur
70 * a large synchronization overhead (for example, with GPUs).
71 */
72 virtual int64_t SyncAndGetElapsedNanos() = 0;
73
  /*! \brief Virtual destructor for this abstract timer interface. */
74 virtual ~TimerNode() {}
75
  /*! \brief Type key string of this node type. */
76 static constexpr const char* _type_key = "TimerNode";
77 TVM_DECLARE_BASE_OBJECT_INFO(TimerNode, Object);
78};
79
80/*! \brief Timer for a specific device.
81 *
82 * This is a managed reference to a TimerNode.
83 *
84 * \sa TimerNode
85 */
86class Timer : public ObjectRef {
87 public:
88 /*!
89 * \brief Get a device specific timer.
90 * \param dev The device to time.
91 * \return A `Timer` that has already been started.
92 *
93 * Use this function to time runtime of arbitrary regions of code on a specific
94 * device. The code that you want to time should be running on the device
95 * otherwise the timer will not return correct results. This is a lower level
96 * interface than TimeEvaluator and only runs the timed code once
97 * (TimeEvaluator runs the code multiple times).
98 *
99 * A default timer is used if a device specific one does not exist. This
100 * timer performs synchronization between the device and CPU, which can lead
101 * to overhead in the reported results.
102 *
103 * Example usage:
104 * \code{.cpp}
105 * Timer t = Timer::Start(Device::cpu());
106 * my_long_running_function();
107 * t->Stop();
108 * ... // some more computation
109 * int64_t nanosecs = t->SyncAndGetElapsedNanos(); // elapsed time in nanoseconds
110 * \endcode
111 *
112 * To add a new device-specific timer, register a new function
113 * "profiling.timer.my_device" (where `my_device` is the `DeviceName` of your
114 * device). This function should accept a `Device` and return a new `Timer`
115 * that has already been started.
116 *
117 * For example, this is how the CPU timer is implemented:
118 * \code{.cpp}
119 * class CPUTimerNode : public TimerNode {
120 * public:
121 * virtual void Start() { start_ = std::chrono::high_resolution_clock::now(); }
122 * virtual void Stop() { duration_ = std::chrono::high_resolution_clock::now() - start_; }
123 * virtual int64_t SyncAndGetElapsedNanos() { return duration_.count(); }
124 * virtual ~CPUTimerNode() {}
125 *
126 * static constexpr const char* _type_key = "CPUTimerNode";
127 * TVM_DECLARE_FINAL_OBJECT_INFO(CPUTimerNode, TimerNode);
128 *
129 * private:
130 * std::chrono::high_resolution_clock::time_point start_;
131 * std::chrono::duration<int64_t, std::nano> duration_;
132 * };
133 * TVM_REGISTER_OBJECT_TYPE(CPUTimerNode);
134 *
135 * TVM_REGISTER_GLOBAL("profiling.timer.cpu").set_body_typed([](Device dev) {
136 * return Timer(make_object<CPUTimerNode>());
137 * });
138 * \endcode
139 */
140 static TVM_DLL Timer Start(Device dev);
141
142 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Timer, ObjectRef, TimerNode);
143};
144
145/*!
146 * \brief Default timer if one does not exist for the device.
147 * \param dev The device to time on.
 * \return The default `Timer` for `dev`.
148 *
149 * Note that this timer performs synchronization between the device and CPU,
150 * which can lead to overhead in the reported results.
151 */
152Timer DefaultTimer(Device dev);
153
154namespace profiling {
155/*! \brief Wrapper for `Device` because `Device` is not passable across the
156 * PackedFunc interface.
157 */
158struct DeviceWrapperNode : public Object {
159 /*! \brief The wrapped device. */
160 Device device;
161
162 /*! \brief Construct a node holding a copy of `device`.
 * \param device The device to wrap.
 */
163 explicit DeviceWrapperNode(Device device) : device(device) {}
164
  /*! \brief Type key string of this node type. */
165 static constexpr const char* _type_key = "runtime.profiling.DeviceWrapper";
166 TVM_DECLARE_BASE_OBJECT_INFO(DeviceWrapperNode, Object);
167};
168
169/*! \brief Wrapper for `Device`. Managed reference to a `DeviceWrapperNode`. */
170class DeviceWrapper : public ObjectRef {
171 public:
  /*! \brief Allocate a new `DeviceWrapperNode` holding `dev` and reference it.
   * \param dev The device to wrap.
   */
172 explicit DeviceWrapper(Device dev) { data_ = make_object<DeviceWrapperNode>(dev); }
173 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(DeviceWrapper, ObjectRef, DeviceWrapperNode);
174};
175
176/*! \brief Data collected from a profiling run. Includes per-call metrics and per-device metrics.
177 */
178class ReportNode : public Object {
179 public:
180 /*! \brief A list of function calls and the metrics recorded for that call.
181 *
182 * Each element is a mapping from metric name to value. Some metrics that
183 * appear in every call are "Name" (the function name), "Argument Shapes",
184 * and "Duration (us)". Values are one of `String`, `PercentNode`,
185 * `DurationNode`, or `CountNode`.
186 */
187 Array<Map<String, ObjectRef>> calls;
188 /*! \brief Metrics collected for the entire run of the model on a per-device basis.
189 *
190 * `device_metrics` is indexed by device name then metric.
191 *
192 * These metrics may be larger than the sum of the same metric in `calls`
193 * because these metrics include the overhead of the executor.
194 */
195 Map<String, Map<String, ObjectRef>> device_metrics;
196 /*! \brief Configuration used for this profiling run. Includes number of threads, executor.
197 *
198 * Values must be an object type that can be used with device_metrics.
199 */
200 Map<String, ObjectRef> configuration;
201 /*! \brief Output `calls` in CSV format.
202 *
203 * Note that this does not include `device_metrics`, it only includes per-call metrics.
 * \return The per-call metrics as a CSV string.
204 */
205 String AsCSV() const;
206 /*! \brief Create a human readable table of profiling metrics.
207 *
208 * \param aggregate Whether or not to join multiple calls to the
209 * same op into a single line.
210 *
211 * \param sort Whether or not to sort call frames by descending
212 * duration. If false and if `aggregate` is false, frames will
213 * be sorted by order of appearance in the program. Order is
214 * undefined if `sort` is false and `aggregate` is true.
215 *
216 * \param compute_col_sums Whether or not to include sum totals for
217 * the Count, Duration, and Percent columns.
218 *
 * \return The formatted table. Note that the positional argument order is
 * (`sort`, `aggregate`, `compute_col_sums`).
219 */
220 String AsTable(bool sort = true, bool aggregate = true, bool compute_col_sums = true) const;
221 /*! \brief Convert this report to JSON.
222 *
223 * Output JSON will be of this format:
224 * \code
225 * {
226 * "calls": [
227 * {
228 * "Duration (us)": {
229 * "microseconds": 12.3
230 * },
231 * "Name": "fused_dense",
232 * "Count": {
233 * "count": 1
234 * },
235 * "Percent": {
236 * "percent": 10.3
237 * }
238 * }
239 * ],
240 * "device_metrics": {
241 * "cpu": {
242 * "Duration (us)": {
243 * "microseconds": 334.2
244 * },
245 * "Percent": {
246 * "percent": 100
247 * }
248 * }
249 * }
250 * }
251 * \endcode
 * \return The report serialized as a JSON string.
252 */
253 String AsJSON() const;
254
  /*! \brief Type key string of this node type. */
255 static constexpr const char* _type_key = "runtime.profiling.Report";
256 TVM_DECLARE_FINAL_OBJECT_INFO(ReportNode, Object);
257};
258
/*! \brief Managed reference to a `ReportNode`.
 *
 * \sa ReportNode
 */
259class Report : public ObjectRef {
260 public:
261 /*! \brief Construct a Report from a set of calls (with associated metrics) and per-device metrics.
262 * \param calls Function calls and associated metrics.
263 * \param device_metrics Per-device metrics for overall execution.
264 * \param configuration Configuration data specific to this profiling run.
265 */
266 explicit Report(Array<Map<String, ObjectRef>> calls,
267 Map<String, Map<String, ObjectRef>> device_metrics,
268 Map<String, ObjectRef> configuration);
269
270 /*! \brief Deserialize a Report from a JSON object. Needed for sending the report over RPC.
271 * \param json Serialized json report from `ReportNode::AsJSON`.
272 * \returns A Report.
273 */
274 static Report FromJSON(String json);
275 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Report, ObjectRef, ReportNode);
276};
277
278/*! \brief Interface for user defined profiling metric collection.
279 *
280 * Users can register their own collector by registering a packed function with
281 * the name "runtime.profiling.metrics.my_collector_name" where
282 * "my_collector_name" is the name of their collector. This function should
283 * take an Array of Device as input which contains the devices the collector
284 * will be run on.
285 *
286 * `MetricCollectorNode`s will be called in the following fashion.
287 * \code
288 * MetricCollector mc;
289 * for (auto op : model) {
290 * auto o = mc->Start(dev); // dev is the device `op` runs on
291 * op();
292 * auto metrics = mc->Stop(o); // metrics are added to the profiling report
293 * }
294 * \endcode
295 */
296class MetricCollectorNode : public Object {
297 public:
298 /*! \brief Initialization call. Called before profiling has started. Any
299 * expensive precomputation should happen here.
300 * \param devs The list of devices this collector will be run on.
301 */
302 virtual void Init(Array<DeviceWrapper> devs) = 0;
303 /*! \brief Start collecting metrics for a function call.
304 * \param dev The device the call will be run on.
305 * \returns An object used to maintain state of the metric collection. This
306 * object will be passed to the corresponding `Stop` call. If the device is
307 * not supported, this function will return a nullptr ObjectRef.
308 */
309 virtual ObjectRef Start(Device dev) = 0;
310 /*! \brief Stop collecting metrics.
311 * \param obj The object created by the corresponding `Start` call.
312 * \returns A set of metric names and the associated values. Values must be
313 * one of DurationNode, PercentNode, CountNode, or StringObj.
314 */
315 virtual Map<String, ObjectRef> Stop(ObjectRef obj) = 0;
316
  /*! \brief Virtual destructor for this abstract collector interface. */
317 virtual ~MetricCollectorNode() {}
318
  /*! \brief Type key string of this node type. */
319 static constexpr const char* _type_key = "runtime.profiling.MetricCollector";
320 TVM_DECLARE_BASE_OBJECT_INFO(MetricCollectorNode, Object);
321};
322
323/*! \brief Wrapper for `MetricCollectorNode`. Managed reference to a `MetricCollectorNode`.
 *
 * \sa MetricCollectorNode
 */
324class MetricCollector : public ObjectRef {
325 public:
326 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetricCollector, ObjectRef, MetricCollectorNode);
327};
328
329/*! \brief Information about a single function or operator call. */
330struct CallFrame {
331 /*! \brief Device on which the call was made */
332 Device dev;
333 /*! \brief Name of the function or op */
334 String name;
335 /*! \brief Runtime of the function or op */
336 Timer timer;
337 /*! \brief Extra performance metrics, keyed by metric name */
338 std::unordered_map<std::string, ObjectRef> extra_metrics;
339 /*! \brief User defined metric collectors. Each pair is the MetricCollector and its
340 * associated data (returned from MetricCollector.Start).
341 */
342 std::vector<std::pair<MetricCollector, ObjectRef>> extra_collectors;
343};
344
345/*! \brief Runtime profiler for function and/or operator calls. Used in the graph
346 * runtime and VM to provide profiling information for all operators.
347 *
348 * Example usage:
349 * \code{.cpp}
350 * Device cpu, gpu;
351 * Profiler prof({cpu, gpu}, {}); // no additional metric collectors
352 * my_gpu_kernel(); // do a warmup iteration
353 * prof.Start();
354 * prof.StartCall("my_gpu_kernel", gpu);
355 * my_gpu_kernel();
356 * prof.StopCall();
357 * prof.StartCall("my_cpu_function", cpu);
358 * my_cpu_function();
359 * prof.StopCall();
360 * prof.Stop();
361 * std::cout << prof.Report()->AsTable() << std::endl; // print profiling report
362 * \endcode
363 */
364class Profiler {
365 public:
366 /*! \brief Constructor.
367 *
368 * The profiler should be constructed before you do any warmup iterations.
369 *
370 * \note
371 * Calling this constructor will reset the TVM threadpool. It is necessary in
372 * order to install thread handlers required by certain collectors.
373 *
374 * \param devs The list of devices the profiler will be running on. Should
375 * include all devices used by profiled operators.
376 * \param metric_collectors Additional `MetricCollector`s to use with this profiler.
377 * \param configuration Additional configuration data to add to the outputted profiling report.
378 */
379 explicit Profiler(std::vector<Device> devs, std::vector<MetricCollector> metric_collectors,
380 std::unordered_map<String, ObjectRef> configuration = {});
381 /*! \brief Start the profiler.
382 *
383 * This function should only be called once per object.
384 */
385 void Start();
386 /*! \brief Stop the profiler.
387 *
388 * This function should only be called once per object after start has been called.
389 */
390 void Stop();
391 /*! \brief Start a function call.
392 * \param name The name of the function being called.
393 * \param dev The device on which the function is running.
394 * \param extra_metrics Optional additional profiling information to add to
395 * the frame (input sizes, allocations).
396 *
397 * `StartCall` may be nested, but each `StartCall` needs a matching
398 * `StopCall`. Function calls are stopped in LIFO order, so calls to
399 * `StartCall` and `StopCall` must be nested properly.
400 */
401 void StartCall(String name, Device dev,
402 std::unordered_map<std::string, ObjectRef> extra_metrics = {});
403 /*! \brief Stop the last `StartCall`.
404 * \param extra_metrics Optional additional profiling information to add to
405 * the frame (input sizes, allocations).
406 */
407 void StopCall(std::unordered_map<std::string, ObjectRef> extra_metrics = {});
408 /*! \brief A report of total runtime between `Start` and `Stop` as
409 * well as individual statistics for each `StartCall`-`StopCall` pair.
410 * \returns A `Report` that can either be formatted as CSV (with `.AsCSV`)
411 * or as a human readable table (with `.AsTable`).
412 */
413 profiling::Report Report();
414 /*! \brief Check if the profiler is currently running.
415 * \returns Whether or not the profiler is running.
416 */
417 bool IsRunning() const { return is_running_; }
418
419 private:
  /*! \brief Devices passed to the constructor. */
420 std::vector<Device> devs_;
  /*! \brief Flag returned by `IsRunning`; initially false. */
421 bool is_running_{false};
  /*! \brief Recorded call frames (presumably the completed calls -- confirm in the implementation). */
422 std::vector<CallFrame> calls_;
  /*! \brief Stack of call frames; `StartCall`/`StopCall` are documented as LIFO
   * (presumably the calls started but not yet stopped -- confirm in the implementation).
   */
423 std::stack<CallFrame> in_flight_;
  /*! \brief Metric collectors used with this profiler. */
424 std::vector<MetricCollector> collectors_;
  /*! \brief Configuration data to add to the profiling report. */
425 std::unordered_map<String, ObjectRef> configuration_;
426};
427
428/*! \brief A duration in time. */
429class DurationNode : public Object {
430 public:
431 /*! \brief The duration as a floating point number of microseconds. */
432 double microseconds;
433
434 /*! \brief Construct a new duration.
435 * \param a The duration in microseconds.
436 */
437 explicit DurationNode(double a) : microseconds(a) {}
438
  /*! \brief Type key string of this node type. */
439 static constexpr const char* _type_key = "runtime.profiling.Duration";
440 TVM_DECLARE_FINAL_OBJECT_INFO(DurationNode, Object);
441};
442
443/*! \brief A percentage of something */
444class PercentNode : public Object {
445 public:
446 /*! \brief The percent as a floating point value out of 100%. i.e. if `percent` is 10 then we have 10%. */
447 double percent;
448
449 /*! \brief Construct a new percentage.
450 * \param a The percentage out of 100.
451 */
452 explicit PercentNode(double a) : percent(a) {}
453
  /*! \brief Type key string of this node type. */
454 static constexpr const char* _type_key = "runtime.profiling.Percent";
455 TVM_DECLARE_FINAL_OBJECT_INFO(PercentNode, Object);
456};
457
458/*! \brief A count of something */
459class CountNode : public Object {
460 public:
461 /*! \brief The actual count */
462 int64_t value;
463
464 /*! \brief Construct a new count.
465 * \param a The count.
466 */
467 explicit CountNode(int64_t a) : value(a) {}
468
  /*! \brief Type key string of this node type. */
469 static constexpr const char* _type_key = "runtime.profiling.Count";
470 TVM_DECLARE_FINAL_OBJECT_INFO(CountNode, Object);
471};
472
473/*! \brief A ratio of two things. */
474class RatioNode : public Object {
475 public:
476 /*! \brief The ratio as a double precision floating point number. */
477 double ratio;
478
479 /*! \brief Construct a new ratio.
480 * \param a The ratio.
481 */
482 explicit RatioNode(double a) : ratio(a) {}
483
  /*! \brief Type key string of this node type. */
484 static constexpr const char* _type_key = "runtime.profiling.Ratio";
485 TVM_DECLARE_FINAL_OBJECT_INFO(RatioNode, Object);
486};
487
488/*! \brief String representation of the shapes of an array of NDArrays.
489 * \param shapes Array of NDArrays to get the shapes of.
490 * \return A textual representation of the shapes. For example: `float32[2], int64[1, 2]`.
491 */
492String ShapeString(const std::vector<NDArray>& shapes);
493/*! \brief String representation of a shape encoded as an NDArray.
494 * \param shape NDArray containing the shape.
495 * \param dtype The dtype of the shape.
496 * \return A textual representation of the shape. For example: `float32[2]`.
497 */
498String ShapeString(NDArray shape, DLDataType dtype);
499/*! \brief String representation of a shape encoded as a vector of integers.
500 * \param shape Shape as a vector of integers.
501 * \param dtype The dtype of the shape.
502 * \return A textual representation of the shape. For example: `float32[2]`.
503 */
504String ShapeString(const std::vector<int64_t>& shape, DLDataType dtype);
505
506/*! \brief Collect performance information of a function execution. Usually
507 * used with a compiled PrimFunc (via tvm.build).
508 *
509 * This information can include performance counters like cache hits and FLOPs
510 * that are useful in debugging performance issues of individual PrimFuncs.
511 * Different metrics can be collected depending on which MetricCollector is
512 * used.
513 *
514 * Example usage:
515 * \code{.cpp}
516 * // Use PAPI to measure the number of floating point operations.
517 * PackedFunc profiler = ProfileFunction(
518 * mod, "main", kDLCPU, 0, 10, {CreatePAPIMetricCollector({{kDLCPU, 0}, {"PAPI_FP_OPS"}})});
519 * Map<String, ObjectRef> metrics = profiler(arg1, arg2, arg3);
520 * std::cout << metrics << std::endl;
521 * \endcode
522 *
523 * \param mod Module to profile. Usually a PrimFunc that has been compiled to machine code.
524 * \param func_name Name of function to run in the module.
525 * \param device_type Device type to run on. Profiling will include performance
526 * metrics specific to this device type.
527 * \param device_id Id of device to run on.
528 * \param warmup_iters Number of iterations of the function to run before collecting
529 * performance information. Recommend to set this larger
530 * than 0 so that cache effects are consistent.
531 * \param collectors List of different
532 * ways to collect metrics. See MetricCollector.
533 * \returns A PackedFunc which takes the same arguments as the `mod[func_name]`
534 * and returns performance metrics as a `Map<String, ObjectRef>` where
535 * values can be `CountNode`, `DurationNode`, `PercentNode`.
536 */
537PackedFunc ProfileFunction(Module mod, std::string func_name, int device_type, int device_id,
538 int warmup_iters, Array<MetricCollector> collectors);
539
540/*!
541 * \brief Wrap a timer function to measure the time cost of a given packed function.
542 *
543 * Approximate implementation:
544 * \code{.py}
545 * f() # warmup
546 * for i in range(repeat):
547 * f_preproc()
548 * while True:
549 * start = time()
550 * for j in range(number):
551 * f()
552 * duration_ms = time() - start
553 * if duration_ms >= min_repeat_ms:
554 * break
555 * else:
556 * number = (min_repeat_ms / (duration_ms / number)) + 1
557 * if cooldown_interval_ms and i % repeats_to_cooldown == 0:
558 * sleep(cooldown_interval_ms)
559 * \endcode
560 *
561 * \param f The function to be timed.
562 * \param dev The device.
563 * \param number The number of times to run this function for taking average.
564 * We call these runs as one `repeat` of measurement.
565 * \param repeat The number of times to repeat the measurement.
566 * In total, the function will be invoked (1 + number x repeat) times,
567 * where the first one is warm up and will be discarded.
568 * The returned result contains `repeat` costs,
569 * each of which is an average of `number` costs.
570 * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds.
571 * By default, one `repeat` contains `number` runs. If this parameter is set,
572 * the parameter `number` will be dynamically adjusted to meet the
573 * minimum duration requirement of one `repeat`.
574 * i.e., When the run time of one `repeat` falls below this time,
575 * the `number` parameter will be automatically increased.
576 * \param limit_zero_time_iterations The maximum number of repeats when
577 * measured time is equal to 0. It helps to avoid hanging during measurements.
578 * \param cooldown_interval_ms The cooldown interval in milliseconds between the number of repeats
579 * defined by `repeats_to_cooldown`.
580 * \param repeats_to_cooldown The number of repeats before the
581 * cooldown is activated.
582 * \param f_preproc The function to be executed before we execute time
583 * evaluator. Defaults to `nullptr`.
584 * \return f_timer A timer function.
585 */
586PackedFunc WrapTimeEvaluator(PackedFunc f, Device dev, int number, int repeat, int min_repeat_ms,
587 int limit_zero_time_iterations, int cooldown_interval_ms,
588 int repeats_to_cooldown, PackedFunc f_preproc = nullptr);
589
590} // namespace profiling
591} // namespace runtime
592} // namespace tvm
593
594#endif // TVM_RUNTIME_PROFILING_H_
595