1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/ir/instrument.cc |
22 | * \brief Infrastructure for instrumentation. |
23 | */ |
#include <dmlc/thread_local.h>
#include <tvm/ir/instrument.h>
#include <tvm/ir/transform.h>
#include <tvm/node/repr_printer.h>
#include <tvm/runtime/registry.h>

#include <chrono>
#include <iomanip>
#include <sstream>
#include <stack>
#include <tuple>
#include <utility>
#include <vector>
31 | |
32 | namespace tvm { |
33 | namespace instrument { |
34 | |
35 | /*! |
36 | * \brief Base PassInstrument implementation |
37 | * \sa BasePassInstrument |
38 | */ |
39 | class BasePassInstrumentNode : public PassInstrumentNode { |
40 | public: |
41 | /*! \brief Callback to run when entering PassContext. */ |
42 | runtime::TypedPackedFunc<void()> enter_pass_ctx_callback; |
43 | /*! \brief Callback to run when exiting PassContext. */ |
44 | runtime::TypedPackedFunc<void()> exit_pass_ctx_callback; |
45 | |
46 | /*! \brief Callback determines whether to run a pass or not. */ |
47 | runtime::TypedPackedFunc<bool(const IRModule&, const transform::PassInfo&)> should_run_callback; |
48 | |
49 | /*! \brief Callback to run before a pass. */ |
50 | runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)> |
51 | run_before_pass_callback; |
52 | /*! \brief Callback to run after a pass. */ |
53 | runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)> |
54 | run_after_pass_callback; |
55 | |
56 | /*! \brief Instrument when entering PassContext. */ |
57 | void EnterPassContext() const final; |
58 | |
59 | /*! \brief Instrument when exiting PassContext. */ |
60 | void ExitPassContext() const final; |
61 | |
62 | /*! |
63 | * \brief Determine whether to run the pass or not. |
64 | * \param mod The module that an optimization pass runs on. |
65 | * \param info The pass information. |
66 | * |
67 | * \return true to run the pass; false to skip the pass. |
68 | */ |
69 | bool ShouldRun(const IRModule&, const transform::PassInfo& info) const final; |
70 | |
71 | /*! |
72 | * \brief Instrument before pass run. |
73 | * \param mod The module that an optimization pass runs on. |
74 | * \param info The pass information. |
75 | */ |
76 | void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const final; |
77 | |
78 | /*! |
79 | * \brief Instrument after pass run. |
80 | * |
81 | * \param mod The module that an optimization pass runs on. |
82 | * \param info The pass information. |
83 | */ |
84 | void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const final; |
85 | |
86 | static constexpr const char* _type_key = "instrument.PassInstrument" ; |
87 | TVM_DECLARE_FINAL_OBJECT_INFO(BasePassInstrumentNode, PassInstrumentNode); |
88 | }; |
89 | |
90 | /*! |
91 | * \brief Managed reference class for BasePassInstrumentNode |
92 | * \sa BasePassInstrumentNode |
93 | */ |
94 | class BasePassInstrument : public PassInstrument { |
95 | public: |
96 | /*! |
97 | * \brief Constructor |
98 | * |
99 | * \param name Name for this instrumentation. |
100 | * |
101 | * |
102 | * \param enter_pass_ctx_callback Callback to call when entering pass context. |
103 | * \param exit_pass_ctx_callback Callback to call when exiting pass context. |
104 | * |
105 | * \param should_run_callback Callback to determine whether pass should run. (return true: enable; |
106 | * return false: disable) |
107 | * |
108 | * \param run_before_pass_callback Callback to call before a pass run. |
109 | * \param run_after_pass_callback Callback to call after a pass run. |
110 | */ |
111 | TVM_DLL BasePassInstrument( |
112 | String name, runtime::TypedPackedFunc<void()> enter_pass_ctx_callback, |
113 | runtime::TypedPackedFunc<void()> exit_pass_ctx_callback, |
114 | runtime::TypedPackedFunc<bool(const IRModule&, const transform::PassInfo&)> |
115 | should_run_callback, |
116 | runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)> |
117 | run_before_pass_callback, |
118 | runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)> |
119 | run_after_pass_callback); |
120 | |
121 | TVM_DEFINE_OBJECT_REF_METHODS(BasePassInstrument, PassInstrument, BasePassInstrumentNode); |
122 | }; |
123 | |
124 | BasePassInstrument::BasePassInstrument( |
125 | String name, runtime::TypedPackedFunc<void()> enter_pass_ctx_callback, |
126 | runtime::TypedPackedFunc<void()> exit_pass_ctx_callback, |
127 | runtime::TypedPackedFunc<bool(const IRModule&, const transform::PassInfo&)> should_run_callback, |
128 | runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)> |
129 | run_before_pass_callback, |
130 | runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)> |
131 | run_after_pass_callback) { |
132 | auto pi = make_object<BasePassInstrumentNode>(); |
133 | pi->name = std::move(name); |
134 | |
135 | pi->enter_pass_ctx_callback = std::move(enter_pass_ctx_callback); |
136 | pi->exit_pass_ctx_callback = std::move(exit_pass_ctx_callback); |
137 | |
138 | pi->should_run_callback = std::move(should_run_callback); |
139 | |
140 | pi->run_before_pass_callback = std::move(run_before_pass_callback); |
141 | pi->run_after_pass_callback = std::move(run_after_pass_callback); |
142 | |
143 | data_ = std::move(pi); |
144 | } |
145 | |
146 | void BasePassInstrumentNode::EnterPassContext() const { |
147 | if (enter_pass_ctx_callback != nullptr) { |
148 | enter_pass_ctx_callback(); |
149 | } |
150 | } |
151 | |
152 | void BasePassInstrumentNode::ExitPassContext() const { |
153 | if (exit_pass_ctx_callback != nullptr) { |
154 | exit_pass_ctx_callback(); |
155 | } |
156 | } |
157 | |
158 | bool BasePassInstrumentNode::ShouldRun(const IRModule& ir_module, |
159 | const transform::PassInfo& pass_info) const { |
160 | if (should_run_callback == nullptr) { |
161 | return true; |
162 | } |
163 | |
164 | return should_run_callback(ir_module, pass_info); |
165 | } |
166 | |
167 | void BasePassInstrumentNode::RunBeforePass(const IRModule& ir_module, |
168 | const transform::PassInfo& pass_info) const { |
169 | if (run_before_pass_callback != nullptr) { |
170 | run_before_pass_callback(ir_module, pass_info); |
171 | } |
172 | } |
173 | |
174 | void BasePassInstrumentNode::RunAfterPass(const IRModule& ir_module, |
175 | const transform::PassInfo& pass_info) const { |
176 | if (run_after_pass_callback != nullptr) { |
177 | run_after_pass_callback(ir_module, pass_info); |
178 | } |
179 | } |
180 | |
181 | TVM_REGISTER_NODE_TYPE(BasePassInstrumentNode); |
182 | |
183 | TVM_REGISTER_GLOBAL("instrument.PassInstrument" ) |
184 | .set_body_typed( |
185 | [](String name, runtime::TypedPackedFunc<void()> enter_pass_ctx, |
186 | runtime::TypedPackedFunc<void()> exit_pass_ctx, |
187 | runtime::TypedPackedFunc<bool(const IRModule&, const transform::PassInfo&)> should_run, |
188 | runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)> |
189 | run_before_pass, |
190 | runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)> |
191 | run_after_pass) { |
192 | return BasePassInstrument(name, enter_pass_ctx, exit_pass_ctx, should_run, |
193 | run_before_pass, run_after_pass); |
194 | }); |
195 | |
196 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
197 | .set_dispatch<BasePassInstrumentNode>([](const ObjectRef& ref, ReprPrinter* p) { |
198 | auto* node = static_cast<const BasePassInstrumentNode*>(ref.get()); |
199 | p->stream << node->name; |
200 | }); |
201 | |
202 | /*! \brief PassProfile stores profiling information for a given pass and its sub-passes. */ |
203 | struct PassProfile { |
204 | // TODO(@altanh): expose PassProfile through TVM Object API |
205 | using Clock = std::chrono::steady_clock; |
206 | using Duration = std::chrono::duration<double, std::micro>; |
207 | using Time = std::chrono::time_point<Clock>; |
208 | |
209 | /*! \brief The name of the pass being profiled. */ |
210 | String name; |
211 | /*! \brief The time when the pass was entered. */ |
212 | Time start; |
213 | /*! \brief The time when the pass completed. */ |
214 | Time end; |
215 | /*! \brief The total duration of the pass, i.e. end - start. */ |
216 | Duration duration; |
217 | /*! \brief PassProfiles for all sub-passes invoked during the execution of the pass. */ |
218 | std::vector<PassProfile> children; |
219 | |
220 | explicit PassProfile(String name) |
221 | : name(name), start(Clock::now()), end(Clock::now()), children() {} |
222 | |
223 | /*! \brief Gets the PassProfile of the currently executing pass. */ |
224 | static PassProfile* Current(); |
225 | /*! \brief Pushes a new PassProfile with the given pass name. */ |
226 | static void EnterPass(String name); |
227 | /*! \brief Pops the current PassProfile. */ |
228 | static void ExitPass(); |
229 | }; |
230 | |
231 | struct PassProfileThreadLocalEntry { |
232 | /*! \brief The placeholder top-level PassProfile. */ |
233 | PassProfile root; |
234 | /*! \brief The stack of PassProfiles for nested passes currently running. */ |
235 | std::stack<PassProfile*> profile_stack; |
236 | |
237 | PassProfileThreadLocalEntry() : root("root" ) {} |
238 | }; |
239 | |
240 | /*! \brief Thread local store to hold the pass profiling data. */ |
241 | typedef dmlc::ThreadLocalStore<PassProfileThreadLocalEntry> PassProfileThreadLocalStore; |
242 | |
243 | void PassProfile::EnterPass(String name) { |
244 | PassProfile* cur = PassProfile::Current(); |
245 | cur->children.emplace_back(name); |
246 | PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back()); |
247 | } |
248 | |
249 | void PassProfile::ExitPass() { |
250 | PassProfile* cur = PassProfile::Current(); |
251 | ICHECK_NE(cur->name, "root" ) << "mismatched enter/exit for pass profiling" ; |
252 | cur->end = PassProfile::Clock::now(); |
253 | cur->duration = std::chrono::duration_cast<PassProfile::Duration>(cur->end - cur->start); |
254 | PassProfileThreadLocalStore::Get()->profile_stack.pop(); |
255 | } |
256 | |
257 | PassProfile* PassProfile::Current() { |
258 | PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); |
259 | if (!entry->profile_stack.empty()) { |
260 | return entry->profile_stack.top(); |
261 | } else { |
262 | return &entry->root; |
263 | } |
264 | } |
265 | |
266 | String RenderPassProfiles() { |
267 | PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get(); |
268 | CHECK(entry->profile_stack.empty()) << "cannot print pass profile while still in a pass!" ; |
269 | |
270 | if (entry->root.children.empty()) { |
271 | LOG(WARNING) << "no passes have been profiled, did you enable pass profiling?" ; |
272 | return String(); |
273 | } |
274 | |
275 | // (depth, parent_duration, pass) |
276 | std::stack<std::tuple<size_t, PassProfile::Duration, PassProfile*>> profiles; |
277 | |
278 | // push top level passes |
279 | PassProfile::Duration top_dur(0); |
280 | for (auto it = entry->root.children.begin(); it != entry->root.children.end(); ++it) { |
281 | top_dur += it->duration; |
282 | } |
283 | for (auto it = entry->root.children.rbegin(); it != entry->root.children.rend(); ++it) { |
284 | profiles.push(std::make_tuple(0, top_dur, &*it)); |
285 | } |
286 | |
287 | std::ostringstream os; |
288 | os << std::fixed; |
289 | |
290 | while (profiles.size() > 0) { |
291 | auto [depth, parent_duration, profile] = profiles.top(); |
292 | profiles.pop(); |
293 | |
294 | // indent depth |
295 | for (size_t i = 0; i < depth; ++i) { |
296 | os << "\t" ; |
297 | } |
298 | |
299 | // calculate time spent in pass itself (excluding sub-passes), and push children |
300 | PassProfile::Duration self_duration = profile->duration; |
301 | for (auto it = profile->children.rbegin(); it != profile->children.rend(); ++it) { |
302 | self_duration -= it->duration; |
303 | profiles.push(std::make_tuple(depth + 1, profile->duration, &*it)); |
304 | } |
305 | |
306 | double parent_pct = profile->duration.count() / parent_duration.count() * 100.0; |
307 | double total_pct = profile->duration.count() / top_dur.count() * 100.0; |
308 | |
309 | os << profile->name << ": " ; |
310 | os << std::setprecision(0); |
311 | os << profile->duration.count() << "us [" << self_duration.count() << "us] " ; |
312 | os << std::setprecision(2) << "(" << total_pct << "%; " << parent_pct << "%)\n" ; |
313 | } |
314 | |
315 | return os.str(); |
316 | } |
317 | |
318 | TVM_REGISTER_GLOBAL("instrument.RenderTimePassProfiles" ).set_body_typed(RenderPassProfiles); |
319 | |
320 | TVM_REGISTER_GLOBAL("instrument.MakePassTimingInstrument" ).set_body_typed([]() { |
321 | auto run_before_pass = [](const IRModule&, const transform::PassInfo& pass_info) { |
322 | PassProfile::EnterPass(pass_info->name); |
323 | return true; |
324 | }; |
325 | |
326 | auto run_after_pass = [](const IRModule&, const transform::PassInfo& pass_info) { |
327 | PassProfile::ExitPass(); |
328 | }; |
329 | |
330 | auto exit_pass_ctx = []() { PassProfileThreadLocalStore::Get()->root.children.clear(); }; |
331 | |
332 | return BasePassInstrument("PassTimingInstrument" , |
333 | /* enter_pass_ctx */ nullptr, exit_pass_ctx, /* should_run */ nullptr, |
334 | run_before_pass, run_after_pass); |
335 | }); |
336 | |
337 | } // namespace instrument |
338 | } // namespace tvm |
339 | |