1#include "taichi/system/profiler.h"
2#include "spdlog/fmt/bundled/color.h"
3
4namespace taichi {
5
6// A profiler's records form a tree structure
7struct ProfilerRecordNode {
8 std::vector<std::unique_ptr<ProfilerRecordNode>> childs;
9 ProfilerRecordNode *parent;
10 std::string name;
11 float64 total_time;
12 // Time per element
13 bool account_tpe;
14 uint64 total_elements;
15 int64 num_samples;
16
17 ProfilerRecordNode(const std::string &name, ProfilerRecordNode *parent) {
18 this->name = name;
19 this->parent = parent;
20 this->total_time = 0.0_f64;
21 this->num_samples = 0ll;
22 this->total_elements = 0ll;
23 this->account_tpe = false;
24 }
25
26 void insert_sample(float64 sample) {
27 num_samples += 1;
28 total_time += sample;
29 }
30
31 void insert_sample(float64 sample, uint64 elements) {
32 account_tpe = true;
33 num_samples += 1;
34 total_time += sample;
35 total_elements += elements;
36 }
37
38 float64 get_averaged() const {
39 return total_time / (float64)std::max(num_samples, int64(1));
40 }
41
42 float64 get_averaged_tpe() const {
43 TI_ASSERT(account_tpe);
44 return total_time / (float64)total_elements;
45 }
46
47 ProfilerRecordNode *get_child(const std::string &name) {
48 for (auto &ch : childs) {
49 if (ch->name == name) {
50 return ch.get();
51 }
52 }
53 childs.push_back(std::make_unique<ProfilerRecordNode>(name, this));
54 return childs.back().get();
55 }
56};
57
58class ProfilerRecords {
59 public:
60 std::unique_ptr<ProfilerRecordNode> root;
61 ProfilerRecordNode *current_node;
62 int current_depth;
63 bool enabled;
64
65 explicit ProfilerRecords(const std::string &name) {
66 root = std::make_unique<ProfilerRecordNode>(
67 fmt::format("[Profiler {}]", name), nullptr);
68 current_node = root.get();
69 current_depth = 0; // depth(root) = 0
70 enabled = true;
71 }
72
73 void clear() {
74 root->childs.clear();
75 current_node = root.get();
76 current_depth = 0;
77 enabled = true;
78 }
79
80 void print(ProfilerRecordNode *node, int depth);
81
82 void print() {
83 fmt::print(fg(fmt::color::cyan), std::string(80, '>') + "\n");
84 print(root.get(), 0);
85 fmt::print(fg(fmt::color::cyan), std::string(80, '>') + "\n");
86 }
87
88 void insert_sample(float64 time) {
89 if (!enabled)
90 return;
91 current_node->insert_sample(time);
92 }
93
94 void insert_sample(float64 time, uint64 tpe) {
95 if (!enabled)
96 return;
97 current_node->insert_sample(time, tpe);
98 }
99
100 void push(const std::string name) {
101 if (!enabled)
102 return;
103 current_node = current_node->get_child(name);
104 current_depth += 1;
105 }
106
107 void pop() {
108 if (!enabled)
109 return;
110 current_node = current_node->parent;
111 current_depth -= 1;
112 }
113
114 static ProfilerRecords &get_this_thread_instance() {
115 // Use a raw pointer so that it lives together with the process
116 static thread_local ProfilerRecords *profiler_records = nullptr;
117 if (profiler_records == nullptr) {
118 profiler_records = Profiling::get_instance().get_this_thread_profiler();
119 }
120 return *profiler_records;
121 }
122};
123
124void ProfilerRecords::print(ProfilerRecordNode *node, int depth) {
125 auto make_indent = [depth](int additional) {
126 for (int i = 0; i < depth + additional; i++) {
127 fmt::print(" ");
128 }
129 };
130 using TimeScale = std::pair<real, std::string>;
131
132 auto get_time_scale = [&](real t) -> TimeScale {
133 if (t < 1e-6) {
134 return std::make_pair(1e9_f, "ns");
135 } else if (t < 1e-3) {
136 return std::make_pair(1e6_f, "us");
137 } else if (t < 1) {
138 return std::make_pair(1e3_f, "ms");
139 } else if (t < 60) {
140 return std::make_pair(1_f, " s");
141 } else if (t < 3600) {
142 return std::make_pair(1.0_f / 60_f, " m");
143 } else {
144 return std::make_pair(1.0_f / 3600_f, "h");
145 }
146 };
147
148 auto get_readable_time_with_scale = [&](real t, TimeScale scale) {
149 return fmt::format("{:7.3f} {}", t * scale.first, scale.second);
150 };
151
152 auto get_readable_time = [&](real t) {
153 auto scale = get_time_scale(t);
154 return get_readable_time_with_scale(t, scale);
155 };
156
157 float64 total_time = node->total_time;
158 fmt::color level_color;
159 if (depth == 0)
160 level_color = fmt::color::red;
161 else if (depth == 1)
162 level_color = fmt::color::light_green;
163 else if (depth == 2)
164 level_color = fmt::color::yellow;
165 else if (depth == 3)
166 level_color = fmt::color::light_blue;
167 else if (depth >= 4)
168 level_color = fmt::color::magenta;
169 if (depth == 0) {
170 // Root node only
171 make_indent(0);
172 fmt::print(fg(level_color), "{}\n", node->name.c_str());
173 }
174 if (total_time < 1e-6f) {
175 for (auto &ch : node->childs) {
176 make_indent(1);
177 auto child_time = ch->total_time;
178 auto bulk_statistics =
179 fmt::format("{} {}", get_readable_time(child_time), ch->name);
180 fmt::print(fg(level_color), "{:40}", bulk_statistics);
181 fmt::print(fg(fmt::color::cyan), " [{} x {}]\n", ch->num_samples,
182 get_readable_time_with_scale(
183 ch->get_averaged(), get_time_scale(ch->get_averaged())));
184 print(ch.get(), depth + 1);
185 }
186 } else {
187 TimeScale scale = get_time_scale(total_time);
188 float64 unaccounted = total_time;
189 for (auto &ch : node->childs) {
190 make_indent(1);
191 auto child_time = ch->total_time;
192 std::string bulk_statistics = fmt::format(
193 "{} {:5.2f}% {}", get_readable_time_with_scale(child_time, scale),
194 child_time * 100.0 / total_time, ch->name);
195 fmt::print(fg(level_color), "{:40}", bulk_statistics);
196 fmt::print(fg(fmt::color::cyan), " [{} x {}]\n", ch->num_samples,
197 get_readable_time_with_scale(
198 ch->get_averaged(), get_time_scale(ch->get_averaged())));
199 if (ch->account_tpe) {
200 make_indent(1);
201 fmt::print(" [TPE] {}\n",
202 get_readable_time(ch->total_time));
203 }
204 print(ch.get(), depth + 1);
205 unaccounted -= child_time;
206 }
207 if (!node->childs.empty() && (unaccounted > total_time * 0.005)) {
208 make_indent(1);
209 fmt::print(fg(level_color), "{} {:5.2f}% {}\n",
210 get_readable_time_with_scale(unaccounted, scale),
211 unaccounted * 100.0 / total_time, "[unaccounted]");
212 }
213 }
214}
215
216ScopedProfiler::ScopedProfiler(std::string name, uint64 elements) {
217 start_time_ = Time::get_time();
218 this->name_ = name;
219 this->elements_ = elements;
220 stopped_ = false;
221 ProfilerRecords::get_this_thread_instance().push(name);
222}
223
224void ScopedProfiler::stop() {
225 TI_ASSERT_INFO(!stopped_, "Profiler already stopped.");
226 float64 elapsed = Time::get_time() - start_time_;
227 if ((int64)elements_ != -1) {
228 ProfilerRecords::get_this_thread_instance().insert_sample(elapsed,
229 elements_);
230 } else {
231 ProfilerRecords::get_this_thread_instance().insert_sample(elapsed);
232 }
233 ProfilerRecords::get_this_thread_instance().pop();
234}
235
236void ScopedProfiler::disable() {
237 ProfilerRecords::get_this_thread_instance().enabled = false;
238}
239
240void ScopedProfiler::enable() {
241 ProfilerRecords::get_this_thread_instance().enabled = true;
242}
243
244ScopedProfiler::~ScopedProfiler() {
245 if (!stopped_) {
246 stop();
247 }
248}
249
250Profiling &Profiling::get_instance() {
251 static auto prof = new Profiling;
252 return *prof;
253}
254
255ProfilerRecords *Profiling::get_this_thread_profiler() {
256 std::lock_guard<std::mutex> _(mut_);
257 auto id = std::this_thread::get_id();
258 std::stringstream ss;
259 ss << id;
260 if (profilers_.find(id) == profilers_.end()) {
261 // Note: thread id may be reused
262 profilers_[id] = new ProfilerRecords(fmt::format("thread {}", ss.str()));
263 }
264 return profilers_[id];
265}
266
267void Profiling::print_profile_info() {
268 std::lock_guard<std::mutex> _(mut_);
269 for (auto p : profilers_) {
270 p.second->print();
271 }
272}
273
274void Profiling::clear_profile_info() {
275 std::lock_guard<std::mutex> _(mut_);
276 for (auto p : profilers_) {
277 p.second->clear();
278 }
279}
280
281} // namespace taichi
282