1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/runtime/profiling.cc
22 * \brief Runtime profiling including timers.
23 */
24
25#include <dmlc/json.h>
26#include <tvm/runtime/c_backend_api.h>
27#include <tvm/runtime/data_type.h>
28#include <tvm/runtime/packed_func.h>
29#include <tvm/runtime/profiling.h>
30#include <tvm/runtime/threading_backend.h>
31
32#include <algorithm>
33#include <chrono>
34#include <iomanip>
35#include <iostream>
36#include <map>
37#include <numeric>
38#include <thread>
39
40namespace tvm {
41namespace runtime {
42
43class DefaultTimerNode : public TimerNode {
44 public:
45 virtual void Start() {
46 TVMSynchronize(device_.device_type, device_.device_id, nullptr);
47 start_ = std::chrono::high_resolution_clock::now();
48 }
49 virtual void Stop() {
50 TVMSynchronize(device_.device_type, device_.device_id, nullptr);
51 duration_ = std::chrono::high_resolution_clock::now() - start_;
52 }
53 virtual int64_t SyncAndGetElapsedNanos() { return duration_.count(); }
54 virtual ~DefaultTimerNode() {}
55
56 explicit DefaultTimerNode(Device dev) : device_(dev) {}
57 static constexpr const char* _type_key = "DefaultTimerNode";
58 TVM_DECLARE_FINAL_OBJECT_INFO(DefaultTimerNode, TimerNode);
59
60 private:
61 std::chrono::high_resolution_clock::time_point start_;
62 std::chrono::duration<int64_t, std::nano> duration_;
63 Device device_;
64};
65
66TVM_REGISTER_OBJECT_TYPE(DefaultTimerNode);
67TVM_REGISTER_OBJECT_TYPE(TimerNode);
68
69Timer DefaultTimer(Device dev) { return Timer(make_object<DefaultTimerNode>(dev)); }
70
71class CPUTimerNode : public TimerNode {
72 public:
73 virtual void Start() { start_ = std::chrono::high_resolution_clock::now(); }
74 virtual void Stop() { duration_ = std::chrono::high_resolution_clock::now() - start_; }
75 virtual int64_t SyncAndGetElapsedNanos() { return duration_.count(); }
76 virtual ~CPUTimerNode() {}
77
78 static constexpr const char* _type_key = "CPUTimerNode";
79 TVM_DECLARE_FINAL_OBJECT_INFO(CPUTimerNode, TimerNode);
80
81 private:
82 std::chrono::high_resolution_clock::time_point start_;
83 std::chrono::duration<int64_t, std::nano> duration_;
84};
85TVM_REGISTER_OBJECT_TYPE(CPUTimerNode);
86
87TVM_REGISTER_GLOBAL("profiling.timer.cpu").set_body_typed([](Device dev) {
88 return Timer(make_object<CPUTimerNode>());
89});
90
91// keep track of which timers are not defined but we have already warned about
92std::set<DLDeviceType> seen_devices;
93std::mutex seen_devices_lock;
94
95Timer Timer::Start(Device dev) {
96 auto f = Registry::Get(std::string("profiling.timer.") + DeviceName(dev.device_type));
97 if (f == nullptr) {
98 {
99 std::lock_guard<std::mutex> lock(seen_devices_lock);
100 if (seen_devices.find(dev.device_type) == seen_devices.end()) {
101 LOG(WARNING)
102 << "No timer implementation for " << DeviceName(dev.device_type)
103 << ", using default timer instead. It may be inaccurate or have extra overhead.";
104 seen_devices.insert(dev.device_type);
105 }
106 }
107 Timer t = DefaultTimer(dev);
108 t->Start();
109 return t;
110 } else {
111 Timer t = f->operator()(dev);
112 t->Start();
113 return t;
114 }
115}
116
117TVM_REGISTER_GLOBAL("profiling.start_timer").set_body_typed(Timer::Start);
118
119namespace profiling {
120
121Profiler::Profiler(std::vector<Device> devs, std::vector<MetricCollector> metric_collectors,
122 std::unordered_map<String, ObjectRef> configuration)
123 : devs_(devs), collectors_(metric_collectors), configuration_(configuration) {
124 is_running_ = false;
125 std::vector<DeviceWrapper> wrapped_devs;
126 for (auto dev : devs) {
127 wrapped_devs.push_back(DeviceWrapper(make_object<DeviceWrapperNode>(dev)));
128 }
129 for (auto& x : collectors_) {
130 x->Init(wrapped_devs);
131 }
132 // reset the thread pool so that PAPI eventset hooks are set in all threads.
133 threading::ResetThreadPool();
134
135 configuration_[String("Number of threads")] =
136 ObjectRef(make_object<CountNode>(threading::NumThreads()));
137}
138
139void Profiler::Start() {
140 is_running_ = true;
141 for (auto dev : devs_) {
142 StartCall("Total", dev, {});
143 }
144}
145
146void Profiler::StartCall(String name, Device dev,
147 std::unordered_map<std::string, ObjectRef> extra_metrics) {
148 std::vector<std::pair<MetricCollector, ObjectRef>> objs;
149 for (auto& collector : collectors_) {
150 ObjectRef obj = collector->Start(dev);
151 if (obj.defined()) {
152 objs.emplace_back(collector, obj);
153 }
154 }
155 in_flight_.push(CallFrame{dev, name, Timer::Start(dev), extra_metrics, objs});
156}
157
158void Profiler::StopCall(std::unordered_map<std::string, ObjectRef> extra_metrics) {
159 CallFrame cf = in_flight_.top();
160 cf.timer->Stop();
161 for (auto& p : extra_metrics) {
162 cf.extra_metrics[p.first] = p.second;
163 }
164 // collect the extra metrics from user defined collectors
165 for (const auto& obj : cf.extra_collectors) {
166 auto collector_metrics = obj.first->Stop(obj.second);
167 for (auto& p : collector_metrics) {
168 cf.extra_metrics[p.first] = p.second;
169 }
170 }
171 in_flight_.pop();
172 calls_.push_back(cf);
173}
174
175void Profiler::Stop() {
176 is_running_ = false;
177 for (size_t i = 0; i < devs_.size(); i++) {
178 StopCall();
179 }
180}
181
182std::vector<int64_t> ToShape(NDArray shape_tensor) {
183 std::vector<int64_t> shape;
184 auto rank = shape_tensor.Shape().size();
185 auto dtype = shape_tensor.DataType();
186
187 // For 0-rank shapes we need to allocate a single scalar.
188 if (rank == 0) {
189 return shape;
190 }
191
192 // Otherwise we should be rank-1, and we will extract the number of dimensions
193 // for the output vector.
194 ICHECK_EQ(rank, 1U) << "shape tensor should be a k-length vector, found " << rank;
195 int64_t ndim = shape_tensor.Shape().at(0);
196 shape.resize(ndim);
197
198 const DLTensor* dl_tensor = shape_tensor.operator->();
199 if (dtype.is_int() && dtype.bits() == 32 && dtype.lanes() == 1) {
200 int32_t* dims = reinterpret_cast<int32_t*>(dl_tensor->data);
201 shape.assign(dims, dims + ndim);
202 } else if (dtype.is_int() && dtype.bits() == 64 && dtype.lanes() == 1) {
203 int64_t* dims = reinterpret_cast<int64_t*>(dl_tensor->data);
204 shape.assign(dims, dims + ndim);
205 } else {
206 LOG(FATAL) << "invalid shape tensor datatype: " << dtype;
207 }
208
209 return shape;
210}
211
212String ShapeString(NDArray shape, DLDataType dtype) { return ShapeString(ToShape(shape), dtype); }
213
214String ShapeString(const std::vector<int64_t>& shape, DLDataType dtype) {
215 std::stringstream sizes;
216 sizes << dtype << "[";
217 for (size_t i = 0; i < shape.size(); i++) {
218 if (i != 0) {
219 sizes << ", ";
220 }
221 sizes << shape[i];
222 }
223 sizes << "]";
224 return String(sizes.str());
225}
226
227String ShapeString(const std::vector<NDArray>& shapes) {
228 std::stringstream sizes;
229 for (const NDArray& ary : shapes) {
230 if (sizes.tellp() > 0) {
231 sizes << ", ";
232 }
233 auto shape = ary.Shape();
234 sizes << ary.DataType() << "[";
235 for (size_t i = 0; i < shape.size(); i++) {
236 if (i != 0) {
237 sizes << ", ";
238 }
239 sizes << shape[i];
240 }
241 sizes << "]";
242 }
243 return String(sizes.str());
244}
245
246String ReportNode::AsCSV() const {
247 // get unique headers
248 std::set<std::string> unique_headers;
249
250 for (auto row : calls) {
251 for (auto p : row) {
252 unique_headers.insert(p.first);
253 }
254 }
255
256 std::vector<std::string> headers;
257 for (auto x : unique_headers) {
258 headers.push_back(x);
259 }
260
261 std::stringstream s;
262
263 for (size_t i = 0; i < headers.size(); i++) {
264 std::string header = headers[i];
265 s << header;
266 if (i < headers.size() - 1) {
267 s << ",";
268 }
269 }
270 s << std::endl;
271 for (auto row : calls) {
272 for (size_t i = 0; i < headers.size(); i++) {
273 std::string header = headers[i];
274 auto it = row.find(header);
275 if (it != row.end()) {
276 std::string val;
277 if ((*it).second.as<CountNode>()) {
278 s << (*it).second.as<CountNode>()->value;
279 } else if ((*it).second.as<DurationNode>()) {
280 s << (*it).second.as<DurationNode>()->microseconds;
281 } else if ((*it).second.as<PercentNode>()) {
282 s << (*it).second.as<PercentNode>()->percent;
283 } else if ((*it).second.as<RatioNode>()) {
284 s << (*it).second.as<RatioNode>()->ratio;
285 } else if ((*it).second.as<StringObj>()) {
286 s << "\"" << Downcast<String>((*it).second) << "\"";
287 }
288 }
289 if (i < headers.size() - 1) {
290 s << ",";
291 }
292 }
293 s << std::endl;
294 }
295 return s.str();
296}
297
298namespace {
299void metric_as_json(std::ostream& os, ObjectRef o) {
300 if (o.as<StringObj>()) {
301 os << "{\"string\":"
302 << "\"" << Downcast<String>(o) << "\""
303 << "}";
304 } else if (const CountNode* n = o.as<CountNode>()) {
305 os << "{\"count\":" << n->value << "}";
306 } else if (const DurationNode* n = o.as<DurationNode>()) {
307 os << "{\"microseconds\":" << std::setprecision(std::numeric_limits<double>::max_digits10)
308 << std::fixed << n->microseconds << "}";
309 } else if (const PercentNode* n = o.as<PercentNode>()) {
310 os << "{\"percent\":" << std::setprecision(std::numeric_limits<double>::max_digits10)
311 << std::fixed << n->percent << "}";
312 } else if (const RatioNode* n = o.as<RatioNode>()) {
313 os << "{\"ratio\":" << std::setprecision(std::numeric_limits<double>::max_digits10)
314 << std::fixed << n->ratio << "}";
315 } else {
316 LOG(FATAL) << "Unprintable type " << o->GetTypeKey();
317 }
318}
319} // namespace
320
321String ReportNode::AsJSON() const {
322 std::ostringstream s;
323 // DMLC's JSONWriter does not allow us to write a key value pair without
324 // implementing Write for the value. We want a specific write for the value,
325 // so we would have to implement a custom data structure for each type of
326 // value we want to print. Instead we construct the json by hand because it
327 // is easier.
328 s << "{";
329
330 s << "\"calls\":[";
331 for (size_t i = 0; i < calls.size(); i++) {
332 size_t j = 0;
333 s << "{";
334 for (const auto& kv : calls[i]) {
335 s << "\"" << kv.first << "\":";
336 metric_as_json(s, kv.second);
337 if (j < calls[i].size() - 1) {
338 s << ",";
339 }
340 j++;
341 }
342 s << "}";
343 if (i < calls.size() - 1) {
344 s << ",";
345 }
346 }
347 s << "],"; // end calls
348
349 s << "\"device_metrics\":{";
350 size_t i = 0;
351 for (const auto& dev_kv : device_metrics) {
352 size_t j = 0;
353 s << "\"" << dev_kv.first << "\":{";
354 for (const auto& metric_kv : dev_kv.second) {
355 s << "\"" << metric_kv.first << "\":";
356 metric_as_json(s, metric_kv.second);
357 if (j < dev_kv.second.size() - 1) {
358 s << ",";
359 }
360 j++;
361 }
362 s << "}";
363 if (i < device_metrics.size() - 1) {
364 s << ",";
365 }
366 i++;
367 }
368 s << "},"; // end device metrics
369
370 s << "\"configuration\":{";
371 size_t k = 0;
372 for (const auto& kv : configuration) {
373 s << "\"" << kv.first << "\":";
374 metric_as_json(s, kv.second);
375 if (k < configuration.size() - 1) {
376 s << ",";
377 }
378 k++;
379 }
380 s << "}"; // end configuration
381 s << "}";
382 return s.str();
383}
384
385// Aggregate a set of values for a metric. Computes sum for Duration, Count,
386// and Percent; average for Ratio; and assumes all Strings are the same. All
387// ObjectRefs in metrics must have the same type.
388ObjectRef AggregateMetric(const std::vector<ObjectRef>& metrics) {
389 ICHECK_GT(metrics.size(), 0) << "Must pass a non-zero number of metrics";
390 if (metrics[0].as<DurationNode>()) {
391 double sum = 0;
392 for (auto& metric : metrics) {
393 sum += metric.as<DurationNode>()->microseconds;
394 }
395 return ObjectRef(make_object<DurationNode>(sum));
396 } else if (metrics[0].as<CountNode>()) {
397 int64_t sum = 0;
398 for (auto& metric : metrics) {
399 sum += metric.as<CountNode>()->value;
400 }
401 return ObjectRef(make_object<CountNode>(sum));
402 } else if (metrics[0].as<PercentNode>()) {
403 double sum = 0;
404 for (auto& metric : metrics) {
405 sum += metric.as<PercentNode>()->percent;
406 }
407 return ObjectRef(make_object<PercentNode>(sum));
408 } else if (metrics[0].as<RatioNode>()) {
409 double sum = 0;
410 for (auto& metric : metrics) {
411 sum += metric.as<RatioNode>()->ratio;
412 }
413 return ObjectRef(make_object<RatioNode>(sum / metrics.size()));
414 } else if (metrics[0].as<StringObj>()) {
415 for (auto& m : metrics) {
416 if (Downcast<String>(metrics[0]) != Downcast<String>(m)) {
417 return ObjectRef(String(""));
418 }
419 }
420 // Assume all strings in metrics are the same.
421 return metrics[0];
422 } else {
423 LOG(FATAL) << "Can only aggregate metrics with types DurationNode, CountNode, "
424 "PercentNode, RatioNode, and StringObj, but got "
425 << metrics[0]->GetTypeKey();
426 return ObjectRef(); // To silence warnings
427 }
428}
429
// Best effort: imbue the user's default locale into `s` so numbers are printed with
// thousands separators. On a misconfigured system, constructing that locale can throw
// std::runtime_error; the stream then keeps its current locale and no error is raised.
static void set_locale_for_separators(std::stringstream& s) {
  try {
    // "" selects the user's default locale, see man 3 setlocale.
    const std::locale user_locale("");
    s.imbue(user_locale);
  } catch (const std::runtime_error&) {
    // Deliberately ignored; see above.
  }
}
441
442static String print_metric(ObjectRef metric) {
443 std::string val;
444 if (metric.as<CountNode>()) {
445 std::stringstream s;
446 set_locale_for_separators(s);
447 s << std::fixed << metric.as<CountNode>()->value;
448 val = s.str();
449 } else if (metric.as<DurationNode>()) {
450 std::stringstream s;
451 set_locale_for_separators(s);
452 s << std::fixed << std::setprecision(2) << metric.as<DurationNode>()->microseconds;
453 val = s.str();
454 } else if (metric.as<PercentNode>()) {
455 std::stringstream s;
456 s << std::fixed << std::setprecision(2) << metric.as<PercentNode>()->percent;
457 val = s.str();
458 } else if (metric.as<RatioNode>()) {
459 std::stringstream s;
460 set_locale_for_separators(s);
461 s << std::setprecision(2) << metric.as<RatioNode>()->ratio;
462 val = s.str();
463 } else if (metric.as<StringObj>()) {
464 val = Downcast<String>(metric);
465 } else {
466 LOG(FATAL) << "Cannot print metric of type " << metric->GetTypeKey();
467 }
468 return val;
469}
470
471String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) const {
472 // aggregate calls by op hash (or op name if hash is not set) + argument shapes
473 std::vector<Map<String, ObjectRef>> aggregated_calls;
474 if (aggregate) {
475 std::unordered_map<std::string, std::vector<size_t>> aggregates;
476 for (size_t i = 0; i < calls.size(); i++) {
477 auto& frame = calls[i];
478 auto it = frame.find("Hash");
479 std::string name = Downcast<String>(frame["Name"]);
480 if (it != frame.end()) {
481 name = Downcast<String>((*it).second);
482 }
483 if (frame.find("Argument Shapes") != frame.end()) {
484 name += Downcast<String>(frame["Argument Shapes"]);
485 }
486 if (frame.find("Device") != frame.end()) {
487 name += Downcast<String>(frame["Device"]);
488 }
489
490 if (aggregates.find(name) == aggregates.end()) {
491 aggregates[name] = {i};
492 } else {
493 aggregates[name].push_back(i);
494 }
495 }
496 for (const auto& p : aggregates) {
497 std::unordered_map<String, ObjectRef> aggregated;
498 std::unordered_set<std::string> metrics;
499 for (auto& call : calls) {
500 for (auto& metric : call) {
501 metrics.insert(metric.first);
502 }
503 }
504 for (const std::string& metric : metrics) {
505 std::vector<ObjectRef> per_call;
506 for (auto i : p.second) {
507 auto& call = calls[i];
508 auto it = std::find_if(call.begin(), call.end(),
509 [&metric](const std::pair<String, ObjectRef>& call_metric) {
510 return std::string(call_metric.first) == metric;
511 });
512 if (it != call.end()) {
513 per_call.push_back((*it).second);
514 }
515 }
516 if (per_call.size() > 0) {
517 aggregated[metric] = AggregateMetric(per_call);
518 }
519 }
520 aggregated_calls.push_back(aggregated);
521 }
522 } else {
523 for (auto call : calls) {
524 aggregated_calls.push_back(call);
525 }
526 }
527
528 // sort rows by duration
529 if (sort) {
530 std::sort(aggregated_calls.begin(), aggregated_calls.end(),
531 [&](const Map<String, ObjectRef>& a, const Map<String, ObjectRef>& b) {
532 return a.at("Duration (us)").as<DurationNode>()->microseconds >
533 b.at("Duration (us)").as<DurationNode>()->microseconds;
534 });
535 }
536
537 // compute columnwise sums
538 if (compute_col_sums) {
539 std::unordered_map<String, ObjectRef> col_sums;
540 for (auto call : aggregated_calls) {
541 for (auto p : call) {
542 if (p.second.as<CountNode>()) {
543 int64_t val = p.second.as<CountNode>()->value;
544 auto it = col_sums.find(p.first);
545 if (it != col_sums.end()) {
546 val += it->second.as<CountNode>()->value;
547 }
548 col_sums[p.first] = ObjectRef(make_object<CountNode>(val));
549 } else if (p.second.as<DurationNode>()) {
550 double val = p.second.as<DurationNode>()->microseconds;
551 auto it = col_sums.find(p.first);
552 if (it != col_sums.end()) {
553 val += it->second.as<DurationNode>()->microseconds;
554 }
555 col_sums[p.first] = ObjectRef(make_object<DurationNode>(val));
556 } else if (p.second.as<PercentNode>()) {
557 double val = p.second.as<PercentNode>()->percent;
558 auto it = col_sums.find(p.first);
559 if (it != col_sums.end()) {
560 val += it->second.as<PercentNode>()->percent;
561 }
562 col_sums[p.first] = ObjectRef(make_object<PercentNode>(val));
563 } else if (p.second.as<RatioNode>()) {
564 // It does not make sense to sum ratios
565 }
566 }
567 }
568 col_sums["Name"] = String("Sum");
569 aggregated_calls.push_back({{String("Name"), String("----------")}}); // separator
570 aggregated_calls.push_back(col_sums);
571 }
572
573 // per-device metrics
574 for (auto p : device_metrics) {
575 Map<String, ObjectRef> metrics = p.second;
576 metrics.Set("Name", String("Total"));
577 aggregated_calls.push_back(metrics);
578 }
579
580 // Table formatting
581 std::set<std::string> unique_headers;
582 for (auto row : aggregated_calls) {
583 for (auto p : row) {
584 unique_headers.insert(p.first);
585 }
586 }
587
588 // always include these headers in this order
589 std::vector<std::string> headers = {"Name", "Duration (us)", "Percent",
590 "Device", "Count", "Argument Shapes"};
591 for (auto header : unique_headers) {
592 if (std::find(headers.begin(), headers.end(), header) == headers.end()) {
593 headers.push_back(header);
594 }
595 }
596
597 // Switch layout from row major to column major so we can easily compute column widths.
598 std::vector<std::vector<std::string>> cols;
599 for (auto header : headers) {
600 cols.push_back({header});
601 }
602 for (auto row : aggregated_calls) {
603 for (size_t i = 0; i < headers.size(); i++) {
604 auto it = row.find(headers[i]);
605 if (it == row.end()) {
606 // fill empty data with empty strings
607 cols[i].push_back("");
608 } else {
609 cols[i].push_back(print_metric((*it).second));
610 }
611 }
612 }
613
614 std::vector<size_t> widths;
615 for (auto v : cols) {
616 size_t width = 0;
617 for (auto x : v) {
618 width = std::max(width, x.size());
619 }
620 widths.push_back(width);
621 }
622 size_t length = 0;
623 for (auto v : cols) {
624 length = std::max(length, v.size());
625 }
626
627 std::stringstream s;
628 for (size_t row = 0; row < length; row++) {
629 for (size_t col = 0; col < cols.size(); col++) {
630 // left align first column
631 if (col == 0) {
632 s << std::left;
633 } else {
634 s << std::right;
635 }
636 if (row < cols[col].size()) {
637 s << std::setw(widths[col]) << cols[col][row] << " ";
638 } else {
639 s << std::setw(widths[col]) << ""
640 << " ";
641 }
642 }
643 s << std::endl;
644 }
645
646 // Add configuration information. It will not be aligned with the columns.
647 s << std::endl << "Configuration" << std::endl << "-------------" << std::endl;
648 for (auto kv : configuration) {
649 s << kv.first << ": " << print_metric(kv.second) << std::endl;
650 }
651 return s.str();
652}
653
654std::string DeviceString(Device dev) {
655 return DeviceName(dev.device_type) + std::to_string(dev.device_id);
656}
657
658Report Profiler::Report() {
659 // sync all timers and normalize rows
660 std::vector<std::unordered_map<String, ObjectRef>> rows;
661 for (auto& cf : calls_) {
662 std::unordered_map<String, ObjectRef> row;
663 double us = cf.timer->SyncAndGetElapsedNanos() / 1e3;
664 row["Duration (us)"] = ObjectRef(make_object<DurationNode>(us));
665 row["Count"] = ObjectRef(make_object<CountNode>(1));
666 row["Name"] = cf.name;
667 row["Device"] = String(DeviceString(cf.dev));
668 for (auto p : cf.extra_metrics) {
669 row[p.first] = p.second;
670 }
671 rows.push_back(row);
672 }
673
674 // the last couple of call frames are the overall times
675 double overall_time_us = 0;
676 std::unordered_map<String, Map<String, ObjectRef>> device_metrics;
677 for (size_t i = 0; i < devs_.size(); i++) {
678 auto row = rows[rows.size() - 1];
679 rows.pop_back();
680 device_metrics[Downcast<String>(row["Device"])] = row;
681 overall_time_us =
682 std::max(overall_time_us, row["Duration (us)"].as<DurationNode>()->microseconds);
683 }
684
685 // Calculate percentages
686 for (auto& row : rows) {
687 row["Percent"] = ObjectRef(make_object<PercentNode>(
688 row["Duration (us)"].as<DurationNode>()->microseconds / overall_time_us * 100));
689 }
690
691 // convert to map
692 std::vector<Map<String, ObjectRef>> converted_rows;
693 for (const auto& row : rows) {
694 converted_rows.push_back(row);
695 }
696
697 return profiling::Report(converted_rows, device_metrics, configuration_);
698}
699
700Report::Report(Array<Map<String, ObjectRef>> calls,
701 Map<String, Map<String, ObjectRef>> device_metrics,
702 Map<String, ObjectRef> configuration) {
703 auto node = make_object<ReportNode>();
704 node->calls = std::move(calls);
705 node->device_metrics = std::move(device_metrics);
706 node->configuration = std::move(configuration);
707 data_ = std::move(node);
708}
709
710Map<String, ObjectRef> parse_metrics(dmlc::JSONReader* reader) {
711 reader->BeginObject();
712 std::string metric_name, metric_value_name;
713 Map<String, ObjectRef> metrics;
714 while (reader->NextObjectItem(&metric_name)) {
715 ObjectRef o;
716 reader->BeginObject();
717 reader->NextObjectItem(&metric_value_name);
718 if (metric_value_name == "microseconds") {
719 double microseconds;
720 reader->Read(&microseconds);
721 o = ObjectRef(make_object<DurationNode>(microseconds));
722 } else if (metric_value_name == "percent") {
723 double percent;
724 reader->Read(&percent);
725 o = ObjectRef(make_object<PercentNode>(percent));
726 } else if (metric_value_name == "count") {
727 int64_t count;
728 reader->Read(&count);
729 o = ObjectRef(make_object<CountNode>(count));
730 } else if (metric_value_name == "ratio") {
731 double ratio;
732 reader->Read(&ratio);
733 o = ObjectRef(make_object<RatioNode>(ratio));
734 } else if (metric_value_name == "string") {
735 std::string s;
736 reader->Read(&s);
737 o = String(s);
738 } else {
739 LOG(FATAL) << "Cannot parse metric of type " << metric_value_name
740 << " valid types are microseconds, percent, count.";
741 }
742 metrics.Set(metric_name, o);
743 // Necessary to make sure that the parser hits the end of the object.
744 ICHECK(!reader->NextObjectItem(&metric_value_name));
745 // EndObject does not exist, leaving this here for clarity
746 // reader.EndObject();
747 }
748 // reader.EndObject();
749 return metrics;
750}
751
752Report Report::FromJSON(String json) {
753 std::stringstream input(json.operator std::string());
754 dmlc::JSONReader reader(&input);
755 std::string key;
756 Array<Map<String, ObjectRef>> calls;
757 Map<String, Map<String, ObjectRef>> device_metrics;
758 Map<String, ObjectRef> configuration;
759
760 reader.BeginObject();
761 while (reader.NextObjectItem(&key)) {
762 if (key == "calls") {
763 reader.BeginArray();
764 while (reader.NextArrayItem()) {
765 calls.push_back(parse_metrics(&reader));
766 }
767 // reader.EndArray();
768 } else if (key == "device_metrics") {
769 reader.BeginObject();
770 std::string device_name;
771 while (reader.NextObjectItem(&device_name)) {
772 device_metrics.Set(device_name, parse_metrics(&reader));
773 }
774 // reader.EndObject();
775 } else if (key == "configuration") {
776 configuration = parse_metrics(&reader);
777 }
778 }
779
780 return Report(calls, device_metrics, configuration);
781}
782
783TVM_REGISTER_OBJECT_TYPE(DurationNode);
784TVM_REGISTER_OBJECT_TYPE(PercentNode);
785TVM_REGISTER_OBJECT_TYPE(CountNode);
786TVM_REGISTER_OBJECT_TYPE(RatioNode);
787TVM_REGISTER_OBJECT_TYPE(ReportNode);
788TVM_REGISTER_OBJECT_TYPE(DeviceWrapperNode);
789TVM_REGISTER_OBJECT_TYPE(MetricCollectorNode);
790
791TVM_REGISTER_GLOBAL("runtime.profiling.AsTable").set_body_method<Report>(&ReportNode::AsTable);
792TVM_REGISTER_GLOBAL("runtime.profiling.AsCSV").set_body_typed([](Report n) { return n->AsCSV(); });
793TVM_REGISTER_GLOBAL("runtime.profiling.AsJSON").set_body_typed([](Report n) {
794 return n->AsJSON();
795});
796TVM_REGISTER_GLOBAL("runtime.profiling.FromJSON").set_body_typed(Report::FromJSON);
797TVM_REGISTER_GLOBAL("runtime.profiling.DeviceWrapper").set_body_typed([](Device dev) {
798 return DeviceWrapper(dev);
799});
800
801PackedFunc ProfileFunction(Module mod, std::string func_name, int device_type, int device_id,
802 int warmup_iters, Array<MetricCollector> collectors) {
803 // Module::GetFunction is not const, so this lambda has to be mutable
804 return PackedFunc([=](TVMArgs args, TVMRetValue* ret) mutable {
805 PackedFunc f = mod.GetFunction(func_name);
806 CHECK(f.defined()) << "There is no function called \"" << func_name << "\" in the module";
807 Device dev{static_cast<DLDeviceType>(device_type), device_id};
808
809 // warmup
810 for (int i = 0; i < warmup_iters; i++) {
811 f.CallPacked(args, ret);
812 }
813
814 for (auto& collector : collectors) {
815 collector->Init({DeviceWrapper(dev)});
816 }
817 std::vector<Map<String, ObjectRef>> results;
818 results.reserve(collectors.size());
819 std::vector<std::pair<MetricCollector, ObjectRef>> collector_data;
820 collector_data.reserve(collectors.size());
821 for (auto& collector : collectors) {
822 ObjectRef o = collector->Start(dev);
823 // If not defined, then the collector cannot time this device.
824 if (o.defined()) {
825 collector_data.push_back({collector, o});
826 }
827 }
828
829 // TODO(tkonolige): repeated calls if the runtime is small?
830 f.CallPacked(args, ret);
831
832 for (auto& kv : collector_data) {
833 results.push_back(kv.first->Stop(kv.second));
834 }
835 Map<String, ObjectRef> combined_results;
836 for (auto m : results) {
837 for (auto p : m) {
838 // assume that there is no shared metric name between collectors
839 combined_results.Set(p.first, p.second);
840 }
841 }
842 *ret = combined_results;
843 });
844}
845
846TVM_REGISTER_GLOBAL("runtime.profiling.ProfileFunction")
847 .set_body_typed<PackedFunc(Module, String, int, int, int,
848 Array<MetricCollector>)>([](Module mod, String func_name,
849 int device_type, int device_id,
850 int warmup_iters,
851 Array<MetricCollector> collectors) {
852 if (mod->type_key() == std::string("rpc")) {
853 LOG(FATAL)
854 << "Profiling a module over RPC is not yet supported"; // because we can't send
855 // MetricCollectors over rpc.
856 throw;
857 } else {
858 return ProfileFunction(mod, func_name, device_type, device_id, warmup_iters, collectors);
859 }
860 });
861
862PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat, int min_repeat_ms,
863 int limit_zero_time_iterations, int cooldown_interval_ms,
864 int repeats_to_cooldown, PackedFunc f_preproc) {
865 ICHECK(pf != nullptr);
866
867 if (static_cast<int>(dev.device_type) == static_cast<int>(kDLMicroDev)) {
868 auto get_micro_time_evaluator = runtime::Registry::Get("micro._GetMicroTimeEvaluator");
869 ICHECK(get_micro_time_evaluator != nullptr) << "micro backend not enabled";
870 return (*get_micro_time_evaluator)(pf, dev, number, repeat);
871 }
872
873 auto ftimer = [pf, dev, number, repeat, min_repeat_ms, limit_zero_time_iterations,
874 cooldown_interval_ms, repeats_to_cooldown,
875 f_preproc](TVMArgs args, TVMRetValue* rv) mutable {
876 TVMRetValue temp;
877 std::ostringstream os;
878 // skip first time call, to activate lazy compilation components.
879 pf.CallPacked(args, &temp);
880
881 DeviceAPI::Get(dev)->StreamSync(dev, nullptr);
882
883 for (int i = 0; i < repeat; ++i) {
884 if (f_preproc != nullptr) {
885 f_preproc.CallPacked(args, &temp);
886 }
887 double duration_ms = 0.0;
888 int absolute_zero_times = 0;
889 do {
890 if (duration_ms > 0.0) {
891 const double golden_ratio = 1.618;
892 number = static_cast<int>(
893 std::max((min_repeat_ms / (duration_ms / number) + 1), number * golden_ratio));
894 }
895
896 // start timing
897 Timer t = Timer::Start(dev);
898 for (int j = 0; j < number; ++j) {
899 pf.CallPacked(args, &temp);
900 }
901 t->Stop();
902 int64_t t_nanos = t->SyncAndGetElapsedNanos();
903 if (t_nanos == 0) absolute_zero_times++;
904 duration_ms = t_nanos / 1e6;
905 } while (duration_ms < min_repeat_ms && absolute_zero_times < limit_zero_time_iterations);
906
907 double speed = duration_ms / 1e3 / number;
908 os.write(reinterpret_cast<char*>(&speed), sizeof(speed));
909
910 if (cooldown_interval_ms > 0 && (i % repeats_to_cooldown) == 0) {
911 std::this_thread::sleep_for(std::chrono::milliseconds(cooldown_interval_ms));
912 }
913 }
914
915 std::string blob = os.str();
916 TVMByteArray arr;
917 arr.size = blob.length();
918 arr.data = blob.data();
919 // return the time.
920 *rv = arr;
921 };
922 return PackedFunc(ftimer);
923}
924
925TVM_REGISTER_GLOBAL("runtime.profiling.Report")
926 .set_body_typed([](Array<Map<String, ObjectRef>> calls,
927 Map<String, Map<String, ObjectRef>> device_metrics,
928 Map<String, ObjectRef> configuration) {
929 return Report(calls, device_metrics, configuration);
930 });
931
932TVM_REGISTER_GLOBAL("runtime.profiling.Count").set_body_typed([](int64_t count) {
933 return ObjectRef(make_object<CountNode>(count));
934});
935
936TVM_REGISTER_GLOBAL("runtime.profiling.Percent").set_body_typed([](double percent) {
937 return ObjectRef(make_object<PercentNode>(percent));
938});
939
940TVM_REGISTER_GLOBAL("runtime.profiling.Duration").set_body_typed([](double duration) {
941 return ObjectRef(make_object<DurationNode>(duration));
942});
943
944TVM_REGISTER_GLOBAL("runtime.profiling.Ratio").set_body_typed([](double ratio) {
945 return ObjectRef(make_object<RatioNode>(ratio));
946});
947
948} // namespace profiling
949} // namespace runtime
950} // namespace tvm
951