1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file profile_instrumentation.cc
22 */
// Insert profile intrinsics at loop and function level. During codegen,
// these instructions can be replaced with a call to a target-specific handler
// and can be used to capture profiling information such as processor cycles.
26
27#include <tvm/tir/builtin.h>
28#include <tvm/tir/expr.h>
29#include <tvm/tir/stmt.h>
30#include <tvm/tir/stmt_functor.h>
31#include <tvm/tir/transform.h>
32
33namespace tvm {
34namespace tir {
35namespace lwp {
36
37TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_disable_func_prof", Bool);
38TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_max_depth", Integer);
39TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_min_height", Integer);
40TVM_REGISTER_PASS_CONFIG_OPTION("tir.instr_siblings", Bool);
41TVM_REGISTER_PASS_CONFIG_OPTION("tir.reset_start_id", Bool);
42
43static int32_t start_id = 0;
44
/*!
 * \brief Per-loop data collected by LoopAnalyzer and consumed by InstrumentIntrin.
 *
 * All members have default initializers so a default-constructed LoopInfo is fully defined.
 */
struct LoopInfo {
  LoopInfo() = default;
  LoopInfo(unsigned i, unsigned d, unsigned h = 0)
      : id(i), depth(static_cast<int32_t>(d)), height(static_cast<int32_t>(h)) {}

  // Unique id; passed as the argument of the start/end profile intrinsics.
  unsigned id = 0;
  // Loop nesting depth; 0 for an outer-most loop.
  int32_t depth = 0;
  // 0 for an inner-most loop, increasing as we move outward in the loop nest.
  int32_t height = 0;
  // Set to 'true' if the loop shares a SeqStmt with at least one other loop.
  bool has_siblings = false;
  // Set to 'true' if the loop is nested inside a ForKind::kParallel loop. A parallel loop
  // itself keeps 'false' (its intrinsics are placed outside of it); only its descendants
  // are flagged.
  bool has_parallel = false;
};
58
59using LoopInfoMap = std::unordered_map<const ForNode*, LoopInfo>;
60// Traverse loops depth first and assign them a unique number.
61class LoopAnalyzer : public StmtExprVisitor {
62 public:
63 LoopInfoMap Analyze(const Stmt& stmt) {
64 this->VisitStmt(stmt);
65 return loops;
66 }
67 void VisitStmt_(const ForNode* op) final {
68 LoopInfo loop_info(start_id, 0);
69 start_id++;
70 loop_info.height = TraverseLoop(op->body, 0);
71 loops[op] = loop_info;
72 }
73
74 unsigned TraverseLoop(const Stmt& stmt, unsigned parent_depth, bool has_parallel = false) {
75 if (stmt->IsInstance<SeqStmtNode>()) {
76 std::vector<const ForNode*> siblings;
77 unsigned height = 0;
78 bool has_loop = false;
79 const SeqStmtNode* n = stmt.as<SeqStmtNode>();
80 for (Stmt s : n->seq) {
81 if (s->IsInstance<ForNode>()) {
82 has_loop = true;
83 const ForNode* f = s.as<ForNode>();
84 LoopInfo loop_info(start_id, parent_depth + 1);
85 start_id++;
86 bool parent_parallel = false;
87 if (has_parallel) {
88 loop_info.has_parallel = true;
89 parent_parallel = true;
90 } else if (f->kind == ForKind::kParallel) {
91 // has_parallel for the current loop is being set to 'false' since the
92 // intrinsic is added outside of the loop. The instrumentation isn't
93 // allowed for the subsequent nested loops.
94 loop_info.has_parallel = false;
95 parent_parallel = true;
96 }
97 siblings.push_back(f);
98 height = std::max(height, TraverseLoop(f->body, parent_depth + 1, parent_parallel));
99 loop_info.height = height;
100 loops[f] = loop_info;
101 }
102 }
103 if (siblings.size() > 1) {
104 for (auto* l : siblings) {
105 loops[l].has_siblings = true;
106 }
107 }
108 height = has_loop ? height + 1 : height;
109 return height; // Parent's height : max of all children's height
110 } else if (stmt->IsInstance<IfThenElseNode>()) {
111 const IfThenElseNode* n = stmt.as<IfThenElseNode>();
112 unsigned height = TraverseLoop(n->then_case, parent_depth, has_parallel);
113 if (n->else_case) {
114 height = std::max(height, TraverseLoop(n->else_case.value(), parent_depth, has_parallel));
115 }
116 return height;
117 } else if (stmt->IsInstance<ForNode>()) {
118 const ForNode* f = stmt.as<ForNode>();
119 LoopInfo loop_info(start_id, parent_depth + 1);
120 start_id++;
121 bool parent_parallel = false;
122 if (has_parallel) {
123 loop_info.has_parallel = true;
124 parent_parallel = true;
125 } else if (f->kind == ForKind::kParallel) {
126 // has_parallel for the current loop is being set to 'false' since the
127 // intrinsic is added outside of the loop. The instrumentation isn't
128 // allowed for the subsequent nested loops.
129 loop_info.has_parallel = false;
130 parent_parallel = true;
131 }
132 unsigned height = TraverseLoop(f->body, parent_depth + 1, parent_parallel);
133 loop_info.height = height;
134 loops[f] = loop_info;
135 return height + 1;
136 } else if (stmt->IsInstance<LetStmtNode>()) {
137 const LetStmtNode* n = stmt.as<LetStmtNode>();
138 return TraverseLoop(n->body, parent_depth, has_parallel);
139 } else if (stmt->IsInstance<AttrStmtNode>()) {
140 const AttrStmtNode* n = stmt.as<AttrStmtNode>();
141 return TraverseLoop(n->body, parent_depth, has_parallel);
142 } else if (stmt->IsInstance<AllocateNode>()) {
143 const AllocateNode* n = stmt.as<AllocateNode>();
144 return TraverseLoop(n->body, parent_depth, has_parallel);
145 } else {
146 return 0; // inner-most loop
147 }
148 }
149
150 private:
151 LoopInfoMap loops;
152};
153
154class InstrumentIntrin : public StmtMutator {
155 public:
156 InstrumentIntrin(int32_t max_depth, int32_t min_height, bool instr_siblings)
157 : max_instr_depth_(max_depth),
158 min_instr_height_(min_height),
159 instr_siblings_(instr_siblings) {}
160
161 void GetLoopInfo(PrimFuncNode* op) {
162 LoopAnalyzer analzer;
163 loops_ = std::move(analzer.Analyze(op->body));
164 }
165
166 Stmt VisitStmt_(const SeqStmtNode* op) final {
167 Stmt stmt = StmtMutator::VisitStmt_(op);
168 return SeqStmt::Flatten(stmt);
169 }
170
171 Stmt VisitStmt_(const ForNode* op) final {
172 Stmt stmt = StmtMutator::VisitStmt_(op);
173 if (loops_.count(op) < 1) return stmt;
174
175 LoopInfo loop_info = loops_[op];
176
177 if (loop_info.has_parallel) {
178 return stmt;
179 }
180
181 // Exclude inner-most loops from instrumentation. The inner-most loop has
182 // height '0' and it increases as we move outward in the loop nest.
183 if (loop_info.height < min_instr_height_) {
184 return stmt;
185 }
186
187 // Only instrument loops with a sibling
188 if (instr_siblings_ && !loop_info.has_siblings) {
189 return stmt;
190 }
191
192 // If instr_siblings_ is set, ignore max depth for instrumentation
193 if (!instr_siblings_ && loop_info.depth > max_instr_depth_) {
194 return stmt;
195 }
196 PrimExpr id = static_cast<int32_t>(loop_info.id);
197 PrimExpr start_call = Call(DataType::Handle(), builtin::start_profile_intrinsic(), {id});
198 PrimExpr end_call = Call(DataType::Handle(), builtin::end_profile_intrinsic(), {id});
199 const Stmt start_profile = Evaluate(start_call);
200 const Stmt end_profile = Evaluate(end_call);
201 Stmt new_stmt = SeqStmt({start_profile, stmt, end_profile});
202 return new_stmt;
203 }
204
205 private:
206 LoopInfoMap loops_;
207 int32_t max_instr_depth_;
208 int32_t min_instr_height_;
209 bool instr_siblings_;
210};
211
212class CheckParallelLoops : public StmtExprVisitor {
213 public:
214 bool HasParallelLoops(const Stmt& stmt) {
215 this->VisitStmt(stmt);
216 return has_parallel;
217 }
218
219 private:
220 void VisitStmt_(const ForNode* op) final {
221 if (op->kind == ForKind::kParallel) {
222 has_parallel = true;
223 } else {
224 StmtExprVisitor::VisitStmt_(op);
225 }
226 }
227
228 bool has_parallel = false;
229};
230
231PrimFunc AddProfileBuiltins(PrimFunc func, int32_t max_instr_depth, int32_t min_instr_height,
232 bool instr_siblings, bool disable_func_instrumentation) {
233 auto* func_ptr = func.CopyOnWrite();
234
235 PrimExpr e = start_id++;
236 if (!disable_func_instrumentation) {
237 PrimExpr start_call = Call(DataType::Handle(), builtin::start_profile_intrinsic(), {e});
238 PrimExpr end_call = Call(DataType::Handle(), builtin::end_profile_intrinsic(), {e});
239 const Stmt start_profile = Evaluate(start_call);
240 const Stmt end_profile = Evaluate(end_call);
241 func_ptr->body = SeqStmt({start_profile, std::move(func_ptr->body), end_profile});
242 }
243 InstrumentIntrin p(max_instr_depth, min_instr_height, instr_siblings);
244 p.GetLoopInfo(func_ptr);
245 func_ptr->body = p(std::move(func_ptr->body));
246 return func;
247}
248
249} // namespace lwp
250
251namespace transform {
252Pass InstrumentProfileIntrinsics() {
253 auto pass_func = [](IRModule m, PassContext ctx) {
254 auto* mptr = m.CopyOnWrite();
255
256 // All loops with depth <= max_instr_depth are instrumented. By default,
257 // only outer-most loops are instrumented which has a depth of 0.
258 // In addition, loops with siblings are also instrumented provided
259 // their loop depth is >= min_instr_height. This is done to avoid
260 // instrumenting inner-most loops.
261 auto max_instr_depth = ctx->GetConfig<Integer>("tir.lwp_max_depth", Integer(0)).value();
262 auto min_instr_height = ctx->GetConfig<Integer>("tir.lwp_min_height", Integer(1)).value();
263 bool instr_siblings = ctx->GetConfig<Bool>("tir.instr_siblings", Bool(true)).value();
264 bool disable_func_instrumentation =
265 ctx->GetConfig<Bool>("tir.lwp_disable_func_prof", Bool(false)).value();
266 bool reset_start_id = ctx->GetConfig<Bool>("tir.reset_start_id", Bool(false)).value();
267 if (reset_start_id) lwp::start_id = 0;
268 std::vector<std::pair<GlobalVar, PrimFunc>> updates;
269 for (const auto& kv : mptr->functions) {
270 if (auto* n = kv.second.as<PrimFuncNode>()) {
271 PrimFunc func = GetRef<PrimFunc>(n);
272 auto updated_func =
273 lwp::AddProfileBuiltins(func, max_instr_depth.IntValue(), min_instr_height.IntValue(),
274 instr_siblings, disable_func_instrumentation);
275 updates.push_back({kv.first, updated_func});
276 }
277 }
278 for (const auto& pair : updates) {
279 mptr->AddUnchecked(pair.first, pair.second);
280 }
281 return m;
282 };
283
284 return tvm::transform::CreateModulePass(pass_func, 0, "tir.InstrumentProfileIntrinsics", {});
285}
286
287TVM_REGISTER_GLOBAL("tir.transform.InstrumentProfileIntrinsics")
288 .set_body_typed(InstrumentProfileIntrinsics);
289
290} // namespace transform
291
292} // namespace tir
293} // namespace tvm
294