1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file profile_instrumentation.cc |
22 | */ |
// Insert profile intrinsics at loop and function level. During codegen,
// these instructions can be replaced with a call to a target-specific handler
// and can be used to capture profiling information such as processor cycles.
26 | |
27 | #include <tvm/tir/builtin.h> |
28 | #include <tvm/tir/expr.h> |
29 | #include <tvm/tir/stmt.h> |
30 | #include <tvm/tir/stmt_functor.h> |
31 | #include <tvm/tir/transform.h> |
32 | |
33 | namespace tvm { |
34 | namespace tir { |
35 | namespace lwp { |
36 | |
// Pass-config options. Each of these is read with GetConfig inside
// transform::InstrumentProfileIntrinsics, which also supplies the defaults.
TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_disable_func_prof" , Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_max_depth" , Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.lwp_min_height" , Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instr_siblings" , Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.reset_start_id" , Bool);

// Next id to hand out to a profiled function or loop. It is a file-level
// counter: it keeps increasing across functions and across pass invocations,
// and is only set back to 0 when "tir.reset_start_id" is true.
static int32_t start_id = 0;
44 | |
/*!
 * \brief Per-loop bookkeeping used to decide whether a loop gets profiled.
 *
 * All members carry default initializers so that a default-constructed
 * LoopInfo (e.g. one created by unordered_map::operator[]) has a well-defined
 * state instead of indeterminate values.
 */
struct LoopInfo {
  LoopInfo() = default;
  /*!
   * \param i Unique id of the loop.
   * \param d Depth of the loop.
   * \param h Height of the loop (defaults to 0).
   */
  LoopInfo(unsigned i, unsigned d, unsigned h = 0)
      : id(i), depth(static_cast<int32_t>(d)), height(static_cast<int32_t>(h)) {}

  /*! \brief Unique id of the loop. */
  unsigned id{0};
  /*! \brief Nesting depth of the loop. */
  int32_t depth{0};
  /*! \brief Height of the loop; an inner-most loop has height 0. */
  int32_t height{0};
  /*! \brief Whether the loop has at least one sibling loop. */
  bool has_siblings{false};
  /*! \brief Set to 'true' if ForKind::kParallel is set for one of the loop's ancestors. */
  bool has_parallel{false};
};
58 | |
59 | using LoopInfoMap = std::unordered_map<const ForNode*, LoopInfo>; |
60 | // Traverse loops depth first and assign them a unique number. |
61 | class LoopAnalyzer : public StmtExprVisitor { |
62 | public: |
63 | LoopInfoMap Analyze(const Stmt& stmt) { |
64 | this->VisitStmt(stmt); |
65 | return loops; |
66 | } |
67 | void VisitStmt_(const ForNode* op) final { |
68 | LoopInfo loop_info(start_id, 0); |
69 | start_id++; |
70 | loop_info.height = TraverseLoop(op->body, 0); |
71 | loops[op] = loop_info; |
72 | } |
73 | |
74 | unsigned TraverseLoop(const Stmt& stmt, unsigned parent_depth, bool has_parallel = false) { |
75 | if (stmt->IsInstance<SeqStmtNode>()) { |
76 | std::vector<const ForNode*> siblings; |
77 | unsigned height = 0; |
78 | bool has_loop = false; |
79 | const SeqStmtNode* n = stmt.as<SeqStmtNode>(); |
80 | for (Stmt s : n->seq) { |
81 | if (s->IsInstance<ForNode>()) { |
82 | has_loop = true; |
83 | const ForNode* f = s.as<ForNode>(); |
84 | LoopInfo loop_info(start_id, parent_depth + 1); |
85 | start_id++; |
86 | bool parent_parallel = false; |
87 | if (has_parallel) { |
88 | loop_info.has_parallel = true; |
89 | parent_parallel = true; |
90 | } else if (f->kind == ForKind::kParallel) { |
91 | // has_parallel for the current loop is being set to 'false' since the |
92 | // intrinsic is added outside of the loop. The instrumentation isn't |
93 | // allowed for the subsequent nested loops. |
94 | loop_info.has_parallel = false; |
95 | parent_parallel = true; |
96 | } |
97 | siblings.push_back(f); |
98 | height = std::max(height, TraverseLoop(f->body, parent_depth + 1, parent_parallel)); |
99 | loop_info.height = height; |
100 | loops[f] = loop_info; |
101 | } |
102 | } |
103 | if (siblings.size() > 1) { |
104 | for (auto* l : siblings) { |
105 | loops[l].has_siblings = true; |
106 | } |
107 | } |
108 | height = has_loop ? height + 1 : height; |
109 | return height; // Parent's height : max of all children's height |
110 | } else if (stmt->IsInstance<IfThenElseNode>()) { |
111 | const IfThenElseNode* n = stmt.as<IfThenElseNode>(); |
112 | unsigned height = TraverseLoop(n->then_case, parent_depth, has_parallel); |
113 | if (n->else_case) { |
114 | height = std::max(height, TraverseLoop(n->else_case.value(), parent_depth, has_parallel)); |
115 | } |
116 | return height; |
117 | } else if (stmt->IsInstance<ForNode>()) { |
118 | const ForNode* f = stmt.as<ForNode>(); |
119 | LoopInfo loop_info(start_id, parent_depth + 1); |
120 | start_id++; |
121 | bool parent_parallel = false; |
122 | if (has_parallel) { |
123 | loop_info.has_parallel = true; |
124 | parent_parallel = true; |
125 | } else if (f->kind == ForKind::kParallel) { |
126 | // has_parallel for the current loop is being set to 'false' since the |
127 | // intrinsic is added outside of the loop. The instrumentation isn't |
128 | // allowed for the subsequent nested loops. |
129 | loop_info.has_parallel = false; |
130 | parent_parallel = true; |
131 | } |
132 | unsigned height = TraverseLoop(f->body, parent_depth + 1, parent_parallel); |
133 | loop_info.height = height; |
134 | loops[f] = loop_info; |
135 | return height + 1; |
136 | } else if (stmt->IsInstance<LetStmtNode>()) { |
137 | const LetStmtNode* n = stmt.as<LetStmtNode>(); |
138 | return TraverseLoop(n->body, parent_depth, has_parallel); |
139 | } else if (stmt->IsInstance<AttrStmtNode>()) { |
140 | const AttrStmtNode* n = stmt.as<AttrStmtNode>(); |
141 | return TraverseLoop(n->body, parent_depth, has_parallel); |
142 | } else if (stmt->IsInstance<AllocateNode>()) { |
143 | const AllocateNode* n = stmt.as<AllocateNode>(); |
144 | return TraverseLoop(n->body, parent_depth, has_parallel); |
145 | } else { |
146 | return 0; // inner-most loop |
147 | } |
148 | } |
149 | |
150 | private: |
151 | LoopInfoMap loops; |
152 | }; |
153 | |
154 | class InstrumentIntrin : public StmtMutator { |
155 | public: |
156 | InstrumentIntrin(int32_t max_depth, int32_t min_height, bool instr_siblings) |
157 | : max_instr_depth_(max_depth), |
158 | min_instr_height_(min_height), |
159 | instr_siblings_(instr_siblings) {} |
160 | |
161 | void GetLoopInfo(PrimFuncNode* op) { |
162 | LoopAnalyzer analzer; |
163 | loops_ = std::move(analzer.Analyze(op->body)); |
164 | } |
165 | |
166 | Stmt VisitStmt_(const SeqStmtNode* op) final { |
167 | Stmt stmt = StmtMutator::VisitStmt_(op); |
168 | return SeqStmt::Flatten(stmt); |
169 | } |
170 | |
171 | Stmt VisitStmt_(const ForNode* op) final { |
172 | Stmt stmt = StmtMutator::VisitStmt_(op); |
173 | if (loops_.count(op) < 1) return stmt; |
174 | |
175 | LoopInfo loop_info = loops_[op]; |
176 | |
177 | if (loop_info.has_parallel) { |
178 | return stmt; |
179 | } |
180 | |
181 | // Exclude inner-most loops from instrumentation. The inner-most loop has |
182 | // height '0' and it increases as we move outward in the loop nest. |
183 | if (loop_info.height < min_instr_height_) { |
184 | return stmt; |
185 | } |
186 | |
187 | // Only instrument loops with a sibling |
188 | if (instr_siblings_ && !loop_info.has_siblings) { |
189 | return stmt; |
190 | } |
191 | |
192 | // If instr_siblings_ is set, ignore max depth for instrumentation |
193 | if (!instr_siblings_ && loop_info.depth > max_instr_depth_) { |
194 | return stmt; |
195 | } |
196 | PrimExpr id = static_cast<int32_t>(loop_info.id); |
197 | PrimExpr start_call = Call(DataType::Handle(), builtin::start_profile_intrinsic(), {id}); |
198 | PrimExpr end_call = Call(DataType::Handle(), builtin::end_profile_intrinsic(), {id}); |
199 | const Stmt start_profile = Evaluate(start_call); |
200 | const Stmt end_profile = Evaluate(end_call); |
201 | Stmt new_stmt = SeqStmt({start_profile, stmt, end_profile}); |
202 | return new_stmt; |
203 | } |
204 | |
205 | private: |
206 | LoopInfoMap loops_; |
207 | int32_t max_instr_depth_; |
208 | int32_t min_instr_height_; |
209 | bool instr_siblings_; |
210 | }; |
211 | |
212 | class CheckParallelLoops : public StmtExprVisitor { |
213 | public: |
214 | bool HasParallelLoops(const Stmt& stmt) { |
215 | this->VisitStmt(stmt); |
216 | return has_parallel; |
217 | } |
218 | |
219 | private: |
220 | void VisitStmt_(const ForNode* op) final { |
221 | if (op->kind == ForKind::kParallel) { |
222 | has_parallel = true; |
223 | } else { |
224 | StmtExprVisitor::VisitStmt_(op); |
225 | } |
226 | } |
227 | |
228 | bool has_parallel = false; |
229 | }; |
230 | |
231 | PrimFunc AddProfileBuiltins(PrimFunc func, int32_t max_instr_depth, int32_t min_instr_height, |
232 | bool instr_siblings, bool disable_func_instrumentation) { |
233 | auto* func_ptr = func.CopyOnWrite(); |
234 | |
235 | PrimExpr e = start_id++; |
236 | if (!disable_func_instrumentation) { |
237 | PrimExpr start_call = Call(DataType::Handle(), builtin::start_profile_intrinsic(), {e}); |
238 | PrimExpr end_call = Call(DataType::Handle(), builtin::end_profile_intrinsic(), {e}); |
239 | const Stmt start_profile = Evaluate(start_call); |
240 | const Stmt end_profile = Evaluate(end_call); |
241 | func_ptr->body = SeqStmt({start_profile, std::move(func_ptr->body), end_profile}); |
242 | } |
243 | InstrumentIntrin p(max_instr_depth, min_instr_height, instr_siblings); |
244 | p.GetLoopInfo(func_ptr); |
245 | func_ptr->body = p(std::move(func_ptr->body)); |
246 | return func; |
247 | } |
248 | |
249 | } // namespace lwp |
250 | |
251 | namespace transform { |
252 | Pass InstrumentProfileIntrinsics() { |
253 | auto pass_func = [](IRModule m, PassContext ctx) { |
254 | auto* mptr = m.CopyOnWrite(); |
255 | |
256 | // All loops with depth <= max_instr_depth are instrumented. By default, |
257 | // only outer-most loops are instrumented which has a depth of 0. |
258 | // In addition, loops with siblings are also instrumented provided |
259 | // their loop depth is >= min_instr_height. This is done to avoid |
260 | // instrumenting inner-most loops. |
261 | auto max_instr_depth = ctx->GetConfig<Integer>("tir.lwp_max_depth" , Integer(0)).value(); |
262 | auto min_instr_height = ctx->GetConfig<Integer>("tir.lwp_min_height" , Integer(1)).value(); |
263 | bool instr_siblings = ctx->GetConfig<Bool>("tir.instr_siblings" , Bool(true)).value(); |
264 | bool disable_func_instrumentation = |
265 | ctx->GetConfig<Bool>("tir.lwp_disable_func_prof" , Bool(false)).value(); |
266 | bool reset_start_id = ctx->GetConfig<Bool>("tir.reset_start_id" , Bool(false)).value(); |
267 | if (reset_start_id) lwp::start_id = 0; |
268 | std::vector<std::pair<GlobalVar, PrimFunc>> updates; |
269 | for (const auto& kv : mptr->functions) { |
270 | if (auto* n = kv.second.as<PrimFuncNode>()) { |
271 | PrimFunc func = GetRef<PrimFunc>(n); |
272 | auto updated_func = |
273 | lwp::AddProfileBuiltins(func, max_instr_depth.IntValue(), min_instr_height.IntValue(), |
274 | instr_siblings, disable_func_instrumentation); |
275 | updates.push_back({kv.first, updated_func}); |
276 | } |
277 | } |
278 | for (const auto& pair : updates) { |
279 | mptr->AddUnchecked(pair.first, pair.second); |
280 | } |
281 | return m; |
282 | }; |
283 | |
284 | return tvm::transform::CreateModulePass(pass_func, 0, "tir.InstrumentProfileIntrinsics" , {}); |
285 | } |
286 | |
// Register the pass constructor in the global function registry under the
// name "tir.transform.InstrumentProfileIntrinsics".
TVM_REGISTER_GLOBAL("tir.transform.InstrumentProfileIntrinsics" )
    .set_body_typed(InstrumentProfileIntrinsics);
289 | |
290 | } // namespace transform |
291 | |
292 | } // namespace tir |
293 | } // namespace tvm |
294 | |