1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/ir/instrument.cc
22 * \brief Infrastructure for instrumentation.
23 */
24#include <dmlc/thread_local.h>
25#include <tvm/ir/instrument.h>
26#include <tvm/ir/transform.h>
27#include <tvm/node/repr_printer.h>
28#include <tvm/runtime/registry.h>
29
30#include <stack>
31
32namespace tvm {
33namespace instrument {
34
35/*!
36 * \brief Base PassInstrument implementation
37 * \sa BasePassInstrument
38 */
39class BasePassInstrumentNode : public PassInstrumentNode {
40 public:
41 /*! \brief Callback to run when entering PassContext. */
42 runtime::TypedPackedFunc<void()> enter_pass_ctx_callback;
43 /*! \brief Callback to run when exiting PassContext. */
44 runtime::TypedPackedFunc<void()> exit_pass_ctx_callback;
45
46 /*! \brief Callback determines whether to run a pass or not. */
47 runtime::TypedPackedFunc<bool(const IRModule&, const transform::PassInfo&)> should_run_callback;
48
49 /*! \brief Callback to run before a pass. */
50 runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)>
51 run_before_pass_callback;
52 /*! \brief Callback to run after a pass. */
53 runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)>
54 run_after_pass_callback;
55
56 /*! \brief Instrument when entering PassContext. */
57 void EnterPassContext() const final;
58
59 /*! \brief Instrument when exiting PassContext. */
60 void ExitPassContext() const final;
61
62 /*!
63 * \brief Determine whether to run the pass or not.
64 * \param mod The module that an optimization pass runs on.
65 * \param info The pass information.
66 *
67 * \return true to run the pass; false to skip the pass.
68 */
69 bool ShouldRun(const IRModule&, const transform::PassInfo& info) const final;
70
71 /*!
72 * \brief Instrument before pass run.
73 * \param mod The module that an optimization pass runs on.
74 * \param info The pass information.
75 */
76 void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const final;
77
78 /*!
79 * \brief Instrument after pass run.
80 *
81 * \param mod The module that an optimization pass runs on.
82 * \param info The pass information.
83 */
84 void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const final;
85
86 static constexpr const char* _type_key = "instrument.PassInstrument";
87 TVM_DECLARE_FINAL_OBJECT_INFO(BasePassInstrumentNode, PassInstrumentNode);
88};
89
90/*!
91 * \brief Managed reference class for BasePassInstrumentNode
92 * \sa BasePassInstrumentNode
93 */
94class BasePassInstrument : public PassInstrument {
95 public:
96 /*!
97 * \brief Constructor
98 *
99 * \param name Name for this instrumentation.
100 *
101 *
102 * \param enter_pass_ctx_callback Callback to call when entering pass context.
103 * \param exit_pass_ctx_callback Callback to call when exiting pass context.
104 *
105 * \param should_run_callback Callback to determine whether pass should run. (return true: enable;
106 * return false: disable)
107 *
108 * \param run_before_pass_callback Callback to call before a pass run.
109 * \param run_after_pass_callback Callback to call after a pass run.
110 */
111 TVM_DLL BasePassInstrument(
112 String name, runtime::TypedPackedFunc<void()> enter_pass_ctx_callback,
113 runtime::TypedPackedFunc<void()> exit_pass_ctx_callback,
114 runtime::TypedPackedFunc<bool(const IRModule&, const transform::PassInfo&)>
115 should_run_callback,
116 runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)>
117 run_before_pass_callback,
118 runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)>
119 run_after_pass_callback);
120
121 TVM_DEFINE_OBJECT_REF_METHODS(BasePassInstrument, PassInstrument, BasePassInstrumentNode);
122};
123
124BasePassInstrument::BasePassInstrument(
125 String name, runtime::TypedPackedFunc<void()> enter_pass_ctx_callback,
126 runtime::TypedPackedFunc<void()> exit_pass_ctx_callback,
127 runtime::TypedPackedFunc<bool(const IRModule&, const transform::PassInfo&)> should_run_callback,
128 runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)>
129 run_before_pass_callback,
130 runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)>
131 run_after_pass_callback) {
132 auto pi = make_object<BasePassInstrumentNode>();
133 pi->name = std::move(name);
134
135 pi->enter_pass_ctx_callback = std::move(enter_pass_ctx_callback);
136 pi->exit_pass_ctx_callback = std::move(exit_pass_ctx_callback);
137
138 pi->should_run_callback = std::move(should_run_callback);
139
140 pi->run_before_pass_callback = std::move(run_before_pass_callback);
141 pi->run_after_pass_callback = std::move(run_after_pass_callback);
142
143 data_ = std::move(pi);
144}
145
146void BasePassInstrumentNode::EnterPassContext() const {
147 if (enter_pass_ctx_callback != nullptr) {
148 enter_pass_ctx_callback();
149 }
150}
151
152void BasePassInstrumentNode::ExitPassContext() const {
153 if (exit_pass_ctx_callback != nullptr) {
154 exit_pass_ctx_callback();
155 }
156}
157
158bool BasePassInstrumentNode::ShouldRun(const IRModule& ir_module,
159 const transform::PassInfo& pass_info) const {
160 if (should_run_callback == nullptr) {
161 return true;
162 }
163
164 return should_run_callback(ir_module, pass_info);
165}
166
167void BasePassInstrumentNode::RunBeforePass(const IRModule& ir_module,
168 const transform::PassInfo& pass_info) const {
169 if (run_before_pass_callback != nullptr) {
170 run_before_pass_callback(ir_module, pass_info);
171 }
172}
173
174void BasePassInstrumentNode::RunAfterPass(const IRModule& ir_module,
175 const transform::PassInfo& pass_info) const {
176 if (run_after_pass_callback != nullptr) {
177 run_after_pass_callback(ir_module, pass_info);
178 }
179}
180
181TVM_REGISTER_NODE_TYPE(BasePassInstrumentNode);
182
183TVM_REGISTER_GLOBAL("instrument.PassInstrument")
184 .set_body_typed(
185 [](String name, runtime::TypedPackedFunc<void()> enter_pass_ctx,
186 runtime::TypedPackedFunc<void()> exit_pass_ctx,
187 runtime::TypedPackedFunc<bool(const IRModule&, const transform::PassInfo&)> should_run,
188 runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)>
189 run_before_pass,
190 runtime::TypedPackedFunc<void(const IRModule&, const transform::PassInfo&)>
191 run_after_pass) {
192 return BasePassInstrument(name, enter_pass_ctx, exit_pass_ctx, should_run,
193 run_before_pass, run_after_pass);
194 });
195
196TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
197 .set_dispatch<BasePassInstrumentNode>([](const ObjectRef& ref, ReprPrinter* p) {
198 auto* node = static_cast<const BasePassInstrumentNode*>(ref.get());
199 p->stream << node->name;
200 });
201
202/*! \brief PassProfile stores profiling information for a given pass and its sub-passes. */
203struct PassProfile {
204 // TODO(@altanh): expose PassProfile through TVM Object API
205 using Clock = std::chrono::steady_clock;
206 using Duration = std::chrono::duration<double, std::micro>;
207 using Time = std::chrono::time_point<Clock>;
208
209 /*! \brief The name of the pass being profiled. */
210 String name;
211 /*! \brief The time when the pass was entered. */
212 Time start;
213 /*! \brief The time when the pass completed. */
214 Time end;
215 /*! \brief The total duration of the pass, i.e. end - start. */
216 Duration duration;
217 /*! \brief PassProfiles for all sub-passes invoked during the execution of the pass. */
218 std::vector<PassProfile> children;
219
220 explicit PassProfile(String name)
221 : name(name), start(Clock::now()), end(Clock::now()), children() {}
222
223 /*! \brief Gets the PassProfile of the currently executing pass. */
224 static PassProfile* Current();
225 /*! \brief Pushes a new PassProfile with the given pass name. */
226 static void EnterPass(String name);
227 /*! \brief Pops the current PassProfile. */
228 static void ExitPass();
229};
230
231struct PassProfileThreadLocalEntry {
232 /*! \brief The placeholder top-level PassProfile. */
233 PassProfile root;
234 /*! \brief The stack of PassProfiles for nested passes currently running. */
235 std::stack<PassProfile*> profile_stack;
236
237 PassProfileThreadLocalEntry() : root("root") {}
238};
239
240/*! \brief Thread local store to hold the pass profiling data. */
241typedef dmlc::ThreadLocalStore<PassProfileThreadLocalEntry> PassProfileThreadLocalStore;
242
243void PassProfile::EnterPass(String name) {
244 PassProfile* cur = PassProfile::Current();
245 cur->children.emplace_back(name);
246 PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back());
247}
248
249void PassProfile::ExitPass() {
250 PassProfile* cur = PassProfile::Current();
251 ICHECK_NE(cur->name, "root") << "mismatched enter/exit for pass profiling";
252 cur->end = PassProfile::Clock::now();
253 cur->duration = std::chrono::duration_cast<PassProfile::Duration>(cur->end - cur->start);
254 PassProfileThreadLocalStore::Get()->profile_stack.pop();
255}
256
257PassProfile* PassProfile::Current() {
258 PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get();
259 if (!entry->profile_stack.empty()) {
260 return entry->profile_stack.top();
261 } else {
262 return &entry->root;
263 }
264}
265
266String RenderPassProfiles() {
267 PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get();
268 CHECK(entry->profile_stack.empty()) << "cannot print pass profile while still in a pass!";
269
270 if (entry->root.children.empty()) {
271 LOG(WARNING) << "no passes have been profiled, did you enable pass profiling?";
272 return String();
273 }
274
275 // (depth, parent_duration, pass)
276 std::stack<std::tuple<size_t, PassProfile::Duration, PassProfile*>> profiles;
277
278 // push top level passes
279 PassProfile::Duration top_dur(0);
280 for (auto it = entry->root.children.begin(); it != entry->root.children.end(); ++it) {
281 top_dur += it->duration;
282 }
283 for (auto it = entry->root.children.rbegin(); it != entry->root.children.rend(); ++it) {
284 profiles.push(std::make_tuple(0, top_dur, &*it));
285 }
286
287 std::ostringstream os;
288 os << std::fixed;
289
290 while (profiles.size() > 0) {
291 auto [depth, parent_duration, profile] = profiles.top();
292 profiles.pop();
293
294 // indent depth
295 for (size_t i = 0; i < depth; ++i) {
296 os << "\t";
297 }
298
299 // calculate time spent in pass itself (excluding sub-passes), and push children
300 PassProfile::Duration self_duration = profile->duration;
301 for (auto it = profile->children.rbegin(); it != profile->children.rend(); ++it) {
302 self_duration -= it->duration;
303 profiles.push(std::make_tuple(depth + 1, profile->duration, &*it));
304 }
305
306 double parent_pct = profile->duration.count() / parent_duration.count() * 100.0;
307 double total_pct = profile->duration.count() / top_dur.count() * 100.0;
308
309 os << profile->name << ": ";
310 os << std::setprecision(0);
311 os << profile->duration.count() << "us [" << self_duration.count() << "us] ";
312 os << std::setprecision(2) << "(" << total_pct << "%; " << parent_pct << "%)\n";
313 }
314
315 return os.str();
316}
317
318TVM_REGISTER_GLOBAL("instrument.RenderTimePassProfiles").set_body_typed(RenderPassProfiles);
319
320TVM_REGISTER_GLOBAL("instrument.MakePassTimingInstrument").set_body_typed([]() {
321 auto run_before_pass = [](const IRModule&, const transform::PassInfo& pass_info) {
322 PassProfile::EnterPass(pass_info->name);
323 return true;
324 };
325
326 auto run_after_pass = [](const IRModule&, const transform::PassInfo& pass_info) {
327 PassProfile::ExitPass();
328 };
329
330 auto exit_pass_ctx = []() { PassProfileThreadLocalStore::Get()->root.children.clear(); };
331
332 return BasePassInstrument("PassTimingInstrument",
333 /* enter_pass_ctx */ nullptr, exit_pass_ctx, /* should_run */ nullptr,
334 run_before_pass, run_after_pass);
335});
336
337} // namespace instrument
338} // namespace tvm
339