1 | #include "taichi/system/profiler.h" |
2 | #include "spdlog/fmt/bundled/color.h" |
3 | |
4 | namespace taichi { |
5 | |
6 | // A profiler's records form a tree structure |
7 | struct ProfilerRecordNode { |
8 | std::vector<std::unique_ptr<ProfilerRecordNode>> childs; |
9 | ProfilerRecordNode *parent; |
10 | std::string name; |
11 | float64 total_time; |
12 | // Time per element |
13 | bool account_tpe; |
14 | uint64 total_elements; |
15 | int64 num_samples; |
16 | |
17 | ProfilerRecordNode(const std::string &name, ProfilerRecordNode *parent) { |
18 | this->name = name; |
19 | this->parent = parent; |
20 | this->total_time = 0.0_f64; |
21 | this->num_samples = 0ll; |
22 | this->total_elements = 0ll; |
23 | this->account_tpe = false; |
24 | } |
25 | |
26 | void insert_sample(float64 sample) { |
27 | num_samples += 1; |
28 | total_time += sample; |
29 | } |
30 | |
31 | void insert_sample(float64 sample, uint64 elements) { |
32 | account_tpe = true; |
33 | num_samples += 1; |
34 | total_time += sample; |
35 | total_elements += elements; |
36 | } |
37 | |
38 | float64 get_averaged() const { |
39 | return total_time / (float64)std::max(num_samples, int64(1)); |
40 | } |
41 | |
42 | float64 get_averaged_tpe() const { |
43 | TI_ASSERT(account_tpe); |
44 | return total_time / (float64)total_elements; |
45 | } |
46 | |
47 | ProfilerRecordNode *get_child(const std::string &name) { |
48 | for (auto &ch : childs) { |
49 | if (ch->name == name) { |
50 | return ch.get(); |
51 | } |
52 | } |
53 | childs.push_back(std::make_unique<ProfilerRecordNode>(name, this)); |
54 | return childs.back().get(); |
55 | } |
56 | }; |
57 | |
58 | class ProfilerRecords { |
59 | public: |
60 | std::unique_ptr<ProfilerRecordNode> root; |
61 | ProfilerRecordNode *current_node; |
62 | int current_depth; |
63 | bool enabled; |
64 | |
65 | explicit ProfilerRecords(const std::string &name) { |
66 | root = std::make_unique<ProfilerRecordNode>( |
67 | fmt::format("[Profiler {}]" , name), nullptr); |
68 | current_node = root.get(); |
69 | current_depth = 0; // depth(root) = 0 |
70 | enabled = true; |
71 | } |
72 | |
73 | void clear() { |
74 | root->childs.clear(); |
75 | current_node = root.get(); |
76 | current_depth = 0; |
77 | enabled = true; |
78 | } |
79 | |
80 | void print(ProfilerRecordNode *node, int depth); |
81 | |
82 | void print() { |
83 | fmt::print(fg(fmt::color::cyan), std::string(80, '>') + "\n" ); |
84 | print(root.get(), 0); |
85 | fmt::print(fg(fmt::color::cyan), std::string(80, '>') + "\n" ); |
86 | } |
87 | |
88 | void insert_sample(float64 time) { |
89 | if (!enabled) |
90 | return; |
91 | current_node->insert_sample(time); |
92 | } |
93 | |
94 | void insert_sample(float64 time, uint64 tpe) { |
95 | if (!enabled) |
96 | return; |
97 | current_node->insert_sample(time, tpe); |
98 | } |
99 | |
100 | void push(const std::string name) { |
101 | if (!enabled) |
102 | return; |
103 | current_node = current_node->get_child(name); |
104 | current_depth += 1; |
105 | } |
106 | |
107 | void pop() { |
108 | if (!enabled) |
109 | return; |
110 | current_node = current_node->parent; |
111 | current_depth -= 1; |
112 | } |
113 | |
114 | static ProfilerRecords &get_this_thread_instance() { |
115 | // Use a raw pointer so that it lives together with the process |
116 | static thread_local ProfilerRecords *profiler_records = nullptr; |
117 | if (profiler_records == nullptr) { |
118 | profiler_records = Profiling::get_instance().get_this_thread_profiler(); |
119 | } |
120 | return *profiler_records; |
121 | } |
122 | }; |
123 | |
124 | void ProfilerRecords::print(ProfilerRecordNode *node, int depth) { |
125 | auto make_indent = [depth](int additional) { |
126 | for (int i = 0; i < depth + additional; i++) { |
127 | fmt::print(" " ); |
128 | } |
129 | }; |
130 | using TimeScale = std::pair<real, std::string>; |
131 | |
132 | auto get_time_scale = [&](real t) -> TimeScale { |
133 | if (t < 1e-6) { |
134 | return std::make_pair(1e9_f, "ns" ); |
135 | } else if (t < 1e-3) { |
136 | return std::make_pair(1e6_f, "us" ); |
137 | } else if (t < 1) { |
138 | return std::make_pair(1e3_f, "ms" ); |
139 | } else if (t < 60) { |
140 | return std::make_pair(1_f, " s" ); |
141 | } else if (t < 3600) { |
142 | return std::make_pair(1.0_f / 60_f, " m" ); |
143 | } else { |
144 | return std::make_pair(1.0_f / 3600_f, "h" ); |
145 | } |
146 | }; |
147 | |
148 | auto get_readable_time_with_scale = [&](real t, TimeScale scale) { |
149 | return fmt::format("{:7.3f} {}" , t * scale.first, scale.second); |
150 | }; |
151 | |
152 | auto get_readable_time = [&](real t) { |
153 | auto scale = get_time_scale(t); |
154 | return get_readable_time_with_scale(t, scale); |
155 | }; |
156 | |
157 | float64 total_time = node->total_time; |
158 | fmt::color level_color; |
159 | if (depth == 0) |
160 | level_color = fmt::color::red; |
161 | else if (depth == 1) |
162 | level_color = fmt::color::light_green; |
163 | else if (depth == 2) |
164 | level_color = fmt::color::yellow; |
165 | else if (depth == 3) |
166 | level_color = fmt::color::light_blue; |
167 | else if (depth >= 4) |
168 | level_color = fmt::color::magenta; |
169 | if (depth == 0) { |
170 | // Root node only |
171 | make_indent(0); |
172 | fmt::print(fg(level_color), "{}\n" , node->name.c_str()); |
173 | } |
174 | if (total_time < 1e-6f) { |
175 | for (auto &ch : node->childs) { |
176 | make_indent(1); |
177 | auto child_time = ch->total_time; |
178 | auto bulk_statistics = |
179 | fmt::format("{} {}" , get_readable_time(child_time), ch->name); |
180 | fmt::print(fg(level_color), "{:40}" , bulk_statistics); |
181 | fmt::print(fg(fmt::color::cyan), " [{} x {}]\n" , ch->num_samples, |
182 | get_readable_time_with_scale( |
183 | ch->get_averaged(), get_time_scale(ch->get_averaged()))); |
184 | print(ch.get(), depth + 1); |
185 | } |
186 | } else { |
187 | TimeScale scale = get_time_scale(total_time); |
188 | float64 unaccounted = total_time; |
189 | for (auto &ch : node->childs) { |
190 | make_indent(1); |
191 | auto child_time = ch->total_time; |
192 | std::string bulk_statistics = fmt::format( |
193 | "{} {:5.2f}% {}" , get_readable_time_with_scale(child_time, scale), |
194 | child_time * 100.0 / total_time, ch->name); |
195 | fmt::print(fg(level_color), "{:40}" , bulk_statistics); |
196 | fmt::print(fg(fmt::color::cyan), " [{} x {}]\n" , ch->num_samples, |
197 | get_readable_time_with_scale( |
198 | ch->get_averaged(), get_time_scale(ch->get_averaged()))); |
199 | if (ch->account_tpe) { |
200 | make_indent(1); |
201 | fmt::print(" [TPE] {}\n" , |
202 | get_readable_time(ch->total_time)); |
203 | } |
204 | print(ch.get(), depth + 1); |
205 | unaccounted -= child_time; |
206 | } |
207 | if (!node->childs.empty() && (unaccounted > total_time * 0.005)) { |
208 | make_indent(1); |
209 | fmt::print(fg(level_color), "{} {:5.2f}% {}\n" , |
210 | get_readable_time_with_scale(unaccounted, scale), |
211 | unaccounted * 100.0 / total_time, "[unaccounted]" ); |
212 | } |
213 | } |
214 | } |
215 | |
216 | ScopedProfiler::ScopedProfiler(std::string name, uint64 elements) { |
217 | start_time_ = Time::get_time(); |
218 | this->name_ = name; |
219 | this->elements_ = elements; |
220 | stopped_ = false; |
221 | ProfilerRecords::get_this_thread_instance().push(name); |
222 | } |
223 | |
224 | void ScopedProfiler::stop() { |
225 | TI_ASSERT_INFO(!stopped_, "Profiler already stopped." ); |
226 | float64 elapsed = Time::get_time() - start_time_; |
227 | if ((int64)elements_ != -1) { |
228 | ProfilerRecords::get_this_thread_instance().insert_sample(elapsed, |
229 | elements_); |
230 | } else { |
231 | ProfilerRecords::get_this_thread_instance().insert_sample(elapsed); |
232 | } |
233 | ProfilerRecords::get_this_thread_instance().pop(); |
234 | } |
235 | |
236 | void ScopedProfiler::disable() { |
237 | ProfilerRecords::get_this_thread_instance().enabled = false; |
238 | } |
239 | |
240 | void ScopedProfiler::enable() { |
241 | ProfilerRecords::get_this_thread_instance().enabled = true; |
242 | } |
243 | |
244 | ScopedProfiler::~ScopedProfiler() { |
245 | if (!stopped_) { |
246 | stop(); |
247 | } |
248 | } |
249 | |
250 | Profiling &Profiling::get_instance() { |
251 | static auto prof = new Profiling; |
252 | return *prof; |
253 | } |
254 | |
255 | ProfilerRecords *Profiling::get_this_thread_profiler() { |
256 | std::lock_guard<std::mutex> _(mut_); |
257 | auto id = std::this_thread::get_id(); |
258 | std::stringstream ss; |
259 | ss << id; |
260 | if (profilers_.find(id) == profilers_.end()) { |
261 | // Note: thread id may be reused |
262 | profilers_[id] = new ProfilerRecords(fmt::format("thread {}" , ss.str())); |
263 | } |
264 | return profilers_[id]; |
265 | } |
266 | |
267 | void Profiling::print_profile_info() { |
268 | std::lock_guard<std::mutex> _(mut_); |
269 | for (auto p : profilers_) { |
270 | p.second->print(); |
271 | } |
272 | } |
273 | |
274 | void Profiling::clear_profile_info() { |
275 | std::lock_guard<std::mutex> _(mut_); |
276 | for (auto p : profilers_) { |
277 | p.second->clear(); |
278 | } |
279 | } |
280 | |
281 | } // namespace taichi |
282 | |