1#include "kernel_profiler.h"
2
3#include "taichi/system/timer.h"
4#include "taichi/rhi/cuda/cuda_driver.h"
5#include "taichi/rhi/cuda/cuda_profiler.h"
6#include "taichi/system/timeline.h"
7
8namespace taichi::lang {
9
10void KernelProfileStatisticalResult::insert_record(double t) {
11 if (counter == 0) {
12 min = t;
13 max = t;
14 }
15 counter++;
16 min = std::min(min, t);
17 max = std::max(max, t);
18 total += t;
19}
20
21bool KernelProfileStatisticalResult::operator<(
22 const KernelProfileStatisticalResult &o) const {
23 return total > o.total;
24}
25
26void KernelProfilerBase::profiler_start(KernelProfilerBase *profiler,
27 const char *kernel_name) {
28 TI_ASSERT(profiler);
29 profiler->start(std::string(kernel_name));
30}
31
32void KernelProfilerBase::profiler_stop(KernelProfilerBase *profiler) {
33 TI_ASSERT(profiler);
34 profiler->stop();
35}
36
37// TODO : deprecated
38void KernelProfilerBase::query(const std::string &kernel_name,
39 int &counter,
40 double &min,
41 double &max,
42 double &avg) {
43 sync();
44 std::regex name_regex(kernel_name + "(.*)");
45 for (auto &rec : statistical_results_) {
46 if (std::regex_match(rec.name, name_regex)) {
47 if (counter == 0) {
48 counter = rec.counter;
49 min = rec.min;
50 max = rec.max;
51 avg = rec.total / rec.counter;
52 } else if (counter == rec.counter) {
53 min += rec.min;
54 max += rec.max;
55 avg += rec.total / rec.counter;
56 } else {
57 TI_WARN("{}.counter({}) != {}.counter({}).", kernel_name, counter,
58 rec.name, rec.counter);
59 }
60 }
61 }
62}
63
64double KernelProfilerBase::get_total_time() const {
65 return total_time_ms_ / 1000.0;
66}
67
68void KernelProfilerBase::insert_record(const std::string &kernel_name,
69 double duration_ms) {
70 // Trace record
71 KernelProfileTracedRecord record;
72 record.name = kernel_name;
73 record.kernel_elapsed_time_in_ms = duration_ms;
74 traced_records_.push_back(record);
75 // Count record
76 auto it = std::find_if(
77 statistical_results_.begin(), statistical_results_.end(),
78 [&](KernelProfileStatisticalResult &r) { return r.name == record.name; });
79 if (it == statistical_results_.end()) {
80 statistical_results_.emplace_back(record.name);
81 it = std::prev(statistical_results_.end());
82 }
83 it->insert_record(duration_ms);
84 total_time_ms_ += duration_ms;
85}
86
87namespace {
88// A simple profiler that uses Time::get_time()
89class DefaultProfiler : public KernelProfilerBase {
90 public:
91 void sync() override {
92 }
93
94 void update() override {
95 }
96
97 void clear() override {
98 // sync(); //decoupled: trigger from the foront end
99 total_time_ms_ = 0;
100 traced_records_.clear();
101 statistical_results_.clear();
102 }
103
104 void start(const std::string &kernel_name) override {
105 start_t_ = Time::get_time();
106 event_name_ = kernel_name;
107 }
108
109 void stop() override {
110 auto t = Time::get_time() - start_t_;
111 auto ms = t * 1000.0;
112 // trace record
113 KernelProfileTracedRecord record;
114 record.name = event_name_;
115 record.kernel_elapsed_time_in_ms = ms;
116 traced_records_.push_back(record);
117 // count record
118 auto it =
119 std::find_if(statistical_results_.begin(), statistical_results_.end(),
120 [&](KernelProfileStatisticalResult &r) {
121 return r.name == event_name_;
122 });
123 if (it == statistical_results_.end()) {
124 statistical_results_.emplace_back(event_name_);
125 it = std::prev(statistical_results_.end());
126 }
127 it->insert_record(ms);
128 total_time_ms_ += ms;
129 }
130
131 private:
132 double start_t_;
133 std::string event_name_;
134};
135
136} // namespace
137
138std::unique_ptr<KernelProfilerBase> make_profiler(Arch arch, bool enable) {
139 if (!enable)
140 return nullptr;
141 if (arch == Arch::cuda) {
142#if defined(TI_WITH_CUDA)
143 return std::make_unique<KernelProfilerCUDA>(enable);
144#else
145 TI_NOT_IMPLEMENTED;
146#endif
147 } else {
148 return std::make_unique<DefaultProfiler>();
149 }
150}
151
152} // namespace taichi::lang
153