1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tir/analysis/usmp/extract_buffer_info.cc |
22 | * |
23 | * \brief This analysis pass consumes a TIR IRModule with a main function |
 * that defines an ordering in the callees to operators and produces BufferInfo
25 | * objects that contains information about tir.allocate nodes and liveness |
26 | * conflicts between other tir.allocate nodes. |
27 | */ |
#include <tvm/arith/analyzer.h>
#include <tvm/relay/executor.h>
#include <tvm/runtime/device_api.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/usmp/analysis.h>
#include <tvm/tir/usmp/utils.h>

#include <algorithm>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "../../../runtime/thread_storage_scope.h"
40 | |
41 | namespace tvm { |
42 | namespace tir { |
43 | namespace usmp { |
44 | |
45 | /*! |
46 | * \brief The visitor class to obtain buffer information |
47 | * |
48 | * The visitor would initiate the traversal from the main |
49 | * function and visits into the operator PrimFuncs. It will |
 * create unique BufferInfo objects for each Allocate node.
51 | * |
52 | * Every time the buffer variable of the allocate node is referenced |
53 | * it will be recorded using the stmt index. However, note that |
 * the same buffer variable could be referenced multiple times
55 | * from different calls. Thereafter, a sweep is done on all the |
56 | * BufferInfo objects using the per-call liveness events. In the sweep, |
57 | * The BufferInfo objects that are live together will be recorded as |
58 | * mutual conflicts of each other. |
59 | */ |
60 | class : public StmtExprVisitor { |
61 | public: |
62 | explicit (const IRModule& module) : module_(module) { |
63 | for (const auto& gv_func : module_->functions) { |
64 | if (gv_func.second->IsInstance<PrimFuncNode>()) { |
65 | functions_.Set(gv_func.first->name_hint, Downcast<PrimFunc>(gv_func.second)); |
66 | } |
67 | } |
68 | // Pushing a scope info for the initial body of the main function |
69 | scope_stack_.push(ScopeInfo()); |
70 | } |
71 | BufferInfoAnalysis operator()(const PrimFunc& func); |
72 | |
73 | private: |
74 | void VisitStmt(const Stmt& n) override; |
75 | void VisitStmt_(const AllocateNode* op) override; |
76 | void VisitStmt_(const AllocateConstNode* op) override; |
77 | void VisitExpr_(const CallNode* op) override; |
78 | void VisitExpr_(const VarNode* op) override; |
79 | void VisitExpr_(const BufferLoadNode* op) override; |
80 | void VisitStmt_(const BufferStoreNode* op) override; |
81 | void VisitStmt_(const ForNode* op) override; |
82 | |
83 | void UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func); |
84 | void RecordAllocateNodeInfo(const AllocateNode* op); |
85 | void RecordAllocateConstNodeInfo(const AllocateConstNode* op); |
86 | void VisitPrimFunc(const PrimFunc& func, const Call& call); |
87 | |
88 | /*! |
89 | * \brief Maintains the mapping of BufferInfo to their associated TIR Statements. |
90 | */ |
91 | Map<BufferInfo, tir::Stmt> ; |
92 | /*! |
93 | * \brief Records the order of calls in the main for stability. |
94 | */ |
95 | std::vector<Call> ; |
96 | /*! |
97 | * \brief Lookup to avoid adding duplicates to `call_order_`. |
98 | */ |
99 | std::unordered_set<Call, ObjectPtrHash, ObjectPtrEqual> ; |
100 | /*! |
101 | * \brief Records first access in-terms of Stmts to each buffer per call |
102 | * |
103 | * This is because multiple calls could happen to the same PrimFunc. |
104 | */ |
105 | std::unordered_map<Call, Map<tir::Stmt, Integer>, ObjectPtrHash, ObjectPtrEqual> |
106 | ; |
107 | /*! |
108 | * \brief Records last access in-terms of Stmts to each buffer per call |
109 | * |
110 | * This is because multiple calls could happen to the same PrimFunc. |
111 | */ |
112 | std::unordered_map<Call, Map<tir::Stmt, Integer>, ObjectPtrHash, ObjectPtrEqual> |
113 | ; |
114 | |
115 | /*! |
116 | * \brief This structure contains information regarding a Allocate node. |
117 | */ |
118 | struct { |
119 | tir::Stmt ; |
120 | PrimFunc ; |
121 | Call ; |
122 | }; |
123 | |
124 | /*! |
125 | * \brief Maintains the mapping of buffer variable to their allocate nodes to ensure |
126 | * that only one BufferInfo object is created. |
127 | */ |
128 | std::unordered_map<tir::Var, AllocateInfo, ObjectPtrHash, ObjectPtrEqual> ; |
129 | /*! |
130 | * \brief Indicates a count of stmts visited so far to use as a metric of liveness |
131 | */ |
132 | int = 0; |
133 | /*! |
134 | * \brief This structure is supposed to contain information around the scope |
135 | * the visitor is currently in. |
136 | */ |
137 | struct { |
138 | /*! |
139 | * \brief We need to record access per call |
140 | */ |
141 | Call ; |
142 | /*! |
143 | * \brief Having access to PrimFunc metadata is useful |
144 | */ |
145 | PrimFunc ; |
146 | /*! |
147 | * \brief We currently support only serial for loops. Therefore |
148 | * need to know what kind of for loop the visitor is in. |
149 | */ |
150 | For ; |
151 | /*! |
152 | * \brief We record the live allocate_nodes because once in loops |
153 | * the liveness range has to be extended to the whole of the nested |
154 | * loops structure. |
155 | */ |
156 | std::unordered_set<Allocate, ObjectPtrHash, ObjectPtrEqual> ; |
157 | /* |
158 | * \brief We record the live allocate_const_nodes because once in loops |
159 | * the liveness range has to be extended to the whole of the nested |
160 | * loops structure. |
161 | */ |
162 | std::unordered_set<AllocateConst, ObjectPtrHash, ObjectPtrEqual> ; |
163 | /*! |
164 | * \brief This is recorded to extend the liveness of all allocates within |
165 | * nested loop structure. |
166 | */ |
167 | Integer ; |
168 | }; |
169 | std::stack<ScopeInfo> ; |
170 | |
171 | /*! |
172 | * \brief A liveness event is an event that when |
173 | * traversing the tir.Stmts where tir.allocate node |
174 | * begins or ceases to be Live. This particular struct |
175 | * is used to solve interval overlap problem using |
176 | * a sweep-line algorithm. For that, we need to record |
177 | * where the liveness event occurred in a chronological |
178 | * order. |
179 | */ |
180 | enum { = 0, = 1 }; |
181 | struct { |
182 | size_t ; |
183 | LivenessEventType ; |
184 | BufferInfo ; |
185 | bool (const LivenessEvent& other) { |
186 | if (tick == other.tick && le_type == other.le_type && buffer_info == other.buffer_info) { |
187 | return true; |
188 | } |
189 | return false; |
190 | } |
191 | }; |
192 | /*! |
193 | * \brief We need to create unique buffer name is the same name is used in |
194 | * two allocate nodes for clarity for memory planning algorithms. |
195 | */ |
196 | std::string GetUniqueBufferName(std::string name); |
197 | |
198 | /*! |
199 | * \brief This is per buffer name counter to aid the generating the above |
200 | * unique name. |
201 | */ |
202 | std::unordered_map<std::string, int> ; |
203 | /*! |
204 | * \brief The TIR main function calls by name to PrimFuncs to be able to |
205 | * support BYOC. Therefore, this Map records functions that are present |
206 | * in the IRModule by name/ |
207 | */ |
208 | Map<String, PrimFunc> ; |
209 | /*! |
210 | * \brief The IRModule being analyzed. |
211 | */ |
212 | IRModule ; |
213 | }; |
214 | |
215 | std::string BufferInfoExtractor::(std::string name) { |
216 | if (buffer_names.find(name) == buffer_names.end()) { |
217 | buffer_names[name] = 1; |
218 | return name; |
219 | } else { |
220 | buffer_names[name] = buffer_names[name] + 1; |
221 | return name + std::to_string(buffer_names[name]); |
222 | } |
223 | } |
224 | |
225 | void BufferInfoExtractor::(const Stmt& n) { |
226 | current_stmt_idx_ += 1; |
227 | StmtExprVisitor::VisitStmt(n); |
228 | } |
229 | |
230 | void BufferInfoExtractor::(const AllocateNode* op) { |
231 | auto size_bytes = CalculateExtentsSize(op); |
232 | // We only statically memory plan only allocates with known |
233 | // compile time sizes. |
234 | if (size_bytes.defined()) { |
235 | if (allocate_infos.find(op->buffer_var) == allocate_infos.end()) { |
236 | // By default, the core compiler is assumed to attach the a default pool to each allocate. |
237 | ICHECK(op->annotations.count(kPoolCandidatesAllocateAttr)) |
238 | << "Every statically sized allocate node needs an pool candidate attribute" ; |
239 | auto pool_candidates = |
240 | Downcast<Array<PoolInfo>>(op->annotations[kPoolCandidatesAllocateAttr]); |
241 | |
242 | ICHECK(pool_candidates.size() > 0) |
243 | << "The AssignPoolInfo pass should at least attach a single PoolInfo. If there were no " |
244 | "user-given arguments for memory pools, the default behaviour is a single size " |
245 | "un-restricted pool is assigned" ; |
246 | PrimFunc func = scope_stack_.top().func; |
247 | Optional<tvm::relay::Executor> executor_config = |
248 | module_->GetAttr<tvm::relay::Executor>(tvm::attr::kExecutor); |
249 | Integer workspace_alignment = 16; |
250 | if (executor_config) { |
251 | workspace_alignment = |
252 | executor_config.value()->GetAttr<Integer>("workspace-byte-alignment" ).value_or(16); |
253 | } |
254 | |
255 | BufferInfoKind bi_kind = BufferInfoKind::kIntermediate; |
256 | String buffer_info_name = op->buffer_var->name_hint; |
257 | if (op->annotations.find(kInputTensorAllocate) != op->annotations.end()) { |
258 | bi_kind = BufferInfoKind::kInput; |
259 | // using original input name instead of the buffer_var name |
260 | // because this name will be used in the lowering to convey |
261 | // the pool allocation. |
262 | buffer_info_name = Downcast<String>(op->annotations[kInputTensorAllocate]); |
263 | } else if (op->annotations.find(kOutputTensorAllocate) != op->annotations.end()) { |
264 | bi_kind = BufferInfoKind::kOutput; |
265 | // using original output name instead of the buffer_var name |
266 | // because this name will be used in the lowering to convey |
267 | // the pool allocation. |
268 | buffer_info_name = Downcast<String>(op->annotations[kOutputTensorAllocate]); |
269 | } |
270 | auto buffer_info = BufferInfo(GetUniqueBufferName(buffer_info_name), size_bytes, |
271 | pool_candidates, workspace_alignment, bi_kind); |
272 | auto allocate = GetRef<Allocate>(op); |
273 | allocate_infos[op->buffer_var] = |
274 | AllocateInfo{allocate, scope_stack_.top().func, scope_stack_.top().call}; |
275 | buffer_info_map_.Set(buffer_info, allocate); |
276 | } else { |
277 | // Update the allocate info with the latest call |
278 | AllocateInfo ai = allocate_infos[op->buffer_var]; |
279 | ai.call = scope_stack_.top().call; |
280 | allocate_infos[op->buffer_var] = ai; |
281 | } |
282 | } |
283 | } |
284 | |
285 | void BufferInfoExtractor::(const AllocateNode* op) { |
286 | ScopeInfo& current_scope_info = scope_stack_.top(); |
287 | const auto& type = Downcast<PointerType>(op->buffer_var->type_annotation); |
288 | const auto& storage_scope = runtime::StorageScope::Create(type->storage_scope); |
289 | |
290 | // If the allocate is in a for loop, USMP currently only looks at serial for loops. |
291 | // If its not a serial for loop, then memory planner will omit them in the current memory planning |
292 | // process leaving them to as tir.allocate nodes for codegen. Additionally, the USMP can only work |
293 | // with buffers that have global storage_scope |
294 | |
295 | if (storage_scope.rank == runtime::StorageRank::kGlobal) { |
296 | if (!current_scope_info.for_loop.defined()) { |
297 | RecordAllocateNodeInfo(op); |
298 | } else if (current_scope_info.for_loop.defined() && |
299 | current_scope_info.for_loop->kind == ForKind::kSerial) { |
300 | RecordAllocateNodeInfo(op); |
301 | } |
302 | } |
303 | StmtExprVisitor::VisitStmt(op->body); |
304 | current_scope_info.allocate_nodes.erase(GetRef<Allocate>(op)); |
305 | } |
306 | |
307 | void BufferInfoExtractor::(const AllocateConstNode* op) { |
308 | ScopeInfo& current_scope_info = scope_stack_.top(); |
309 | RecordAllocateConstNodeInfo(op); |
310 | StmtExprVisitor::VisitStmt(op->body); |
311 | current_scope_info.allocate_const_nodes.erase(GetRef<AllocateConst>(op)); |
312 | } |
313 | |
314 | void BufferInfoExtractor::(const AllocateConstNode* op) { |
315 | if (!op->annotations.count(kPoolCandidatesAllocateAttr)) { |
316 | return; |
317 | } |
318 | Integer size_bytes = CalculateExtentsSize(op); |
319 | ICHECK(size_bytes.defined()) << "constant node size should be defined" ; |
320 | const auto& buffer_var = op->buffer_var; |
321 | if (allocate_infos.find(buffer_var) == allocate_infos.end()) { |
322 | // By default, the core compiler is assumed to attach the a default pool to each allocate. |
323 | ICHECK(op->annotations.count(kPoolCandidatesAllocateAttr)) |
324 | << "Every statically sized allocate node needs an pool candidate attribute" ; |
325 | auto pool_candidates = Downcast<Array<PoolInfo>>(op->annotations[kPoolCandidatesAllocateAttr]); |
326 | ICHECK(pool_candidates.size() > 0) |
327 | << "The core compiler should at least attach a single PoolInfo. If there were no " |
328 | "user-given arguments for memory pools, the default behaviour is a single size " |
329 | "un-restricted pool is assigned" ; |
330 | PrimFunc func = scope_stack_.top().func; |
331 | Optional<tvm::relay::Executor> executor_config = |
332 | module_->GetAttr<tvm::relay::Executor>(tvm::attr::kExecutor); |
333 | Integer alignment = 16; |
334 | if (executor_config) { |
335 | alignment = |
336 | executor_config.value()->GetAttr<Integer>("constant-byte-alignment" ).value_or(alignment); |
337 | } |
338 | auto buffer_info = BufferInfo(GetUniqueBufferName(buffer_var->name_hint), size_bytes, |
339 | pool_candidates, alignment); |
340 | auto allocate = GetRef<AllocateConst>(op); |
341 | allocate_infos[buffer_var] = |
342 | AllocateInfo{allocate, scope_stack_.top().func, scope_stack_.top().call}; |
343 | buffer_info_map_.Set(buffer_info, allocate); |
344 | } else { |
345 | // Update the allocate info with the latest call |
346 | AllocateInfo ai = allocate_infos[buffer_var]; |
347 | ai.call = scope_stack_.top().call; |
348 | allocate_infos[buffer_var] = ai; |
349 | } |
350 | } |
351 | |
352 | void BufferInfoExtractor::(const ForNode* op) { |
353 | ScopeInfo si{scope_stack_.top().call, |
354 | scope_stack_.top().func, |
355 | GetRef<For>(op), |
356 | scope_stack_.top().allocate_nodes, |
357 | scope_stack_.top().allocate_const_nodes, |
358 | scope_stack_.top().initial_stmt_of_the_nested_loops}; |
359 | if (!scope_stack_.top().initial_stmt_of_the_nested_loops.defined()) { |
360 | si.initial_stmt_of_the_nested_loops = Integer(current_stmt_idx_); |
361 | } |
362 | Call current_call = scope_stack_.top().call; |
363 | PrimFunc current_primfunc = scope_stack_.top().func; |
364 | scope_stack_.push(si); |
365 | StmtExprVisitor::VisitStmt_(op); |
366 | // Extending the liveness to beginning of for-loop next and end of the current for-loop |
367 | for (const Allocate& allocate : scope_stack_.top().allocate_nodes) { |
368 | AllocateInfo ai = allocate_infos[allocate->buffer_var]; |
369 | Call update_call = current_call; |
370 | // If the allocate does not belong to current prim func |
371 | // We need to update the call to which the allocate belong to |
372 | if (ai.prim_func != current_primfunc) { |
373 | update_call = ai.call; |
374 | } |
375 | if (scope_stack_.top().initial_stmt_of_the_nested_loops->value < |
376 | buffer_info_start_stmt_idx_[update_call][allocate].IntValue()) { |
377 | buffer_info_start_stmt_idx_[update_call].Set( |
378 | allocate, scope_stack_.top().initial_stmt_of_the_nested_loops->value); |
379 | } |
380 | if (current_stmt_idx_ > buffer_info_end_stmt_idx_[update_call][allocate].IntValue()) { |
381 | buffer_info_end_stmt_idx_[update_call].Set(allocate, current_stmt_idx_); |
382 | } |
383 | } |
384 | scope_stack_.pop(); |
385 | } |
386 | |
387 | void BufferInfoExtractor::(const BufferLoadNode* op) { |
388 | this->VisitExpr(op->buffer->data); |
389 | StmtExprVisitor::VisitExpr_(op); |
390 | } |
391 | |
392 | void BufferInfoExtractor::(const BufferStoreNode* op) { |
393 | this->VisitExpr(op->buffer->data); |
394 | StmtExprVisitor::VisitStmt_(op); |
395 | } |
396 | |
397 | void BufferInfoExtractor::(const VarNode* op) { |
398 | auto var = GetRef<Var>(op); |
399 | Call current_call = scope_stack_.top().call; |
400 | PrimFunc current_primfunc = scope_stack_.top().func; |
401 | if (allocate_infos.count(var)) { |
402 | auto allocate = allocate_infos[var].Allocate; |
403 | auto allocate_primfunc = allocate_infos[var].prim_func; |
404 | Call update_call = current_call; |
405 | if (allocate_primfunc != current_primfunc) { |
406 | // If the allocate node does not belong to the current primfunc. |
407 | // It's access should be reported to the call to PrimFunc that |
408 | // Allocate belong to. |
409 | update_call = allocate_infos[var].call; |
410 | } |
411 | if (buffer_info_start_stmt_idx_[update_call].count(allocate) == 0) { |
412 | buffer_info_start_stmt_idx_[update_call].Set(allocate, current_stmt_idx_); |
413 | } |
414 | buffer_info_end_stmt_idx_[update_call].Set(allocate, current_stmt_idx_); |
415 | |
416 | ScopeInfo& currect_scope_info = scope_stack_.top(); |
417 | if (currect_scope_info.for_loop.defined()) { |
418 | if (allocate->IsInstance<AllocateNode>()) { |
419 | currect_scope_info.allocate_nodes.insert(Downcast<Allocate>(allocate)); |
420 | } else if (allocate->IsInstance<AllocateConstNode>()) { |
421 | currect_scope_info.allocate_const_nodes.insert(Downcast<AllocateConst>(allocate)); |
422 | } else { |
423 | LOG(FATAL) << "Handling of " << allocate->GetTypeKey() << " is not implemented" ; |
424 | } |
425 | } |
426 | } |
427 | StmtExprVisitor::VisitExpr_(op); |
428 | } |
429 | |
430 | Array<Var> static GetMatchedBuffers(const PrimFunc& func) { |
431 | Array<Var> buffer_vars; |
432 | if (func->params.size() > 0) { |
433 | for (unsigned int i = 0; i < func->params.size() - 1; i++) { |
434 | Var param = func->params[i]; |
435 | buffer_vars.push_back(func->buffer_map[param]->data); |
436 | } |
437 | Var last_param = func->params.back(); |
438 | // Checks whether last var is present in the buffer map |
439 | // because it could be the resource handle |
440 | if (func->buffer_map.find(last_param) != func->buffer_map.end()) { |
441 | buffer_vars.push_back(func->buffer_map[last_param]->data); |
442 | } |
443 | } |
444 | return buffer_vars; |
445 | } |
446 | |
447 | void BufferInfoExtractor::(const Array<PrimExpr>& args, const PrimFunc& func) { |
448 | auto param_buffers = GetMatchedBuffers(func); |
449 | // Last var could be a resource handle that does not have a Buffer |
450 | ICHECK(args.size() == param_buffers.size() || args.size() - 1 == param_buffers.size()); |
451 | for (size_t i = 0; i < param_buffers.size(); i++) { |
452 | auto arg = args[i]; |
453 | auto param_buf = param_buffers[i]; |
454 | // If tir.allocates are passed in to functions |
455 | // The function params are re-directed to point |
456 | // to the original allocate |
457 | if (arg->IsInstance<LoadNode>()) { |
458 | auto load = Downcast<Load>(arg); |
459 | if (allocate_infos.count(load->buffer_var)) { |
460 | allocate_infos[param_buf] = allocate_infos[load->buffer_var]; |
461 | } |
462 | } else if (arg->IsInstance<VarNode>()) { |
463 | auto var = Downcast<Var>(arg); |
464 | if (allocate_infos.count(var)) { |
465 | allocate_infos[param_buf] = allocate_infos[var]; |
466 | } |
467 | } |
468 | } |
469 | } |
470 | |
471 | void BufferInfoExtractor::(const PrimFunc& func, const Call& call) { |
472 | ScopeInfo si{call, |
473 | func, |
474 | scope_stack_.top().for_loop, |
475 | scope_stack_.top().allocate_nodes, |
476 | scope_stack_.top().allocate_const_nodes, |
477 | scope_stack_.top().initial_stmt_of_the_nested_loops}; |
478 | if (call_order_contents_.count(call) == 0) { |
479 | call_order_contents_.insert(call); |
480 | call_order_.push_back(call); |
481 | } |
482 | scope_stack_.push(si); |
483 | this->VisitStmt(func->body); |
484 | scope_stack_.pop(); |
485 | } |
486 | |
487 | void BufferInfoExtractor::(const CallNode* op) { |
488 | if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) { |
489 | StringImm func_name = Downcast<StringImm>(op->args[0])->value; |
490 | if (functions_.find(func_name->value) != functions_.end()) { |
491 | auto func = functions_.at(func_name->value); |
492 | auto actual_args = Array<PrimExpr>(op->args.begin() + 1, op->args.end()); |
493 | this->UpdateAliases(actual_args, func); |
494 | VisitPrimFunc(func, GetRef<Call>(op)); |
495 | return; |
496 | } |
497 | } |
498 | if (op->op->IsInstance<PrimFuncNode>()) { |
499 | auto func = Downcast<PrimFunc>(op->op); |
500 | this->UpdateAliases(op->args, func); |
501 | VisitPrimFunc(func, GetRef<Call>(op)); |
502 | return; |
503 | } |
504 | StmtExprVisitor::VisitExpr_(op); |
505 | } |
506 | |
507 | BufferInfoAnalysis BufferInfoExtractor::(const PrimFunc& main_func) { |
508 | VisitPrimFunc(main_func, Call()); |
509 | |
510 | // Create a vector of liveness events |
511 | // associated with each BufferNodes. |
512 | std::vector<LivenessEvent> le_events_timeline; |
513 | for (const auto& kv1 : buffer_info_map_) { |
514 | if (!kv1.second->IsInstance<AllocateNode>() && !kv1.second->IsInstance<AllocateConstNode>()) { |
515 | continue; |
516 | } |
517 | |
518 | auto allocate = Downcast<Stmt>(kv1.second); |
519 | auto buffer_info = Downcast<BufferInfo>(kv1.first); |
520 | |
521 | ICHECK(call_order_.size() >= buffer_info_end_stmt_idx_.size()); |
522 | ICHECK(call_order_.size() >= buffer_info_end_stmt_idx_.size()); |
523 | |
524 | for (const Call& call : call_order_) { |
525 | Map<Stmt, Integer> buffer_info_starts = buffer_info_start_stmt_idx_[call]; |
526 | if (buffer_info_starts.find(allocate) != buffer_info_starts.end()) { |
527 | LivenessEvent le_event_start; |
528 | le_event_start.buffer_info = buffer_info; |
529 | le_event_start.le_type = START; |
530 | le_event_start.tick = buffer_info_starts[allocate].IntValue(); |
531 | le_events_timeline.push_back(le_event_start); |
532 | } |
533 | } |
534 | |
535 | for (const Call& call : call_order_) { |
536 | Map<Stmt, Integer> buffer_info_ends = buffer_info_end_stmt_idx_[call]; |
537 | if (buffer_info_ends.find(allocate) != buffer_info_ends.end()) { |
538 | LivenessEvent le_event_end; |
539 | le_event_end.buffer_info = buffer_info; |
540 | le_event_end.le_type = END; |
541 | le_event_end.tick = buffer_info_ends[allocate].IntValue(); |
542 | le_events_timeline.push_back(le_event_end); |
543 | } |
544 | } |
545 | } |
546 | |
547 | // Sort the liveness events based on the chronological |
548 | // ordering. For events that are simultaneous, START event |
549 | // takes precedence. |
550 | std::sort(le_events_timeline.begin(), le_events_timeline.end(), |
551 | [](const LivenessEvent& lhs, const LivenessEvent& rhs) { |
552 | if (lhs.tick < rhs.tick) { |
553 | return true; |
554 | } else if (lhs.tick == rhs.tick && lhs.le_type == START && rhs.le_type == END) { |
555 | return true; |
556 | } |
557 | return false; |
558 | }); |
559 | |
560 | // Traverse the liveness events using a open set to track what |
561 | // is live while updating the conflicts through out the linear traversal |
562 | |
563 | int open_set_size = 0; |
564 | int max_open_set_size = 0; |
565 | std::unordered_set<BufferInfo, ObjectPtrHash, ObjectPtrEqual> open_set; |
566 | for (const auto& le_event : le_events_timeline) { |
567 | if (le_event.le_type == START) { |
568 | for (const BufferInfo& open_buffer_info : open_set) { |
569 | open_buffer_info->conflicts.push_back(le_event.buffer_info); |
570 | if (le_event.buffer_info != open_buffer_info) { |
571 | le_event.buffer_info->conflicts.push_back(open_buffer_info); |
572 | } |
573 | } |
574 | open_set_size += le_event.buffer_info->size_bytes.IntValue(); |
575 | if (open_set_size > max_open_set_size) { |
576 | max_open_set_size = open_set_size; |
577 | } |
578 | open_set.insert(le_event.buffer_info); |
579 | } else { |
580 | open_set_size -= le_event.buffer_info->size_bytes.IntValue(); |
581 | open_set.erase(le_event.buffer_info); |
582 | } |
583 | } |
584 | |
585 | // All ConstantPoolInfo items should have conflicts with each other |
586 | // as they will be placed in RO segment and pre-initialized. To achieve this |
587 | // first, split buffers to vars (WorkspacePoolInfo items) and constants (ConstantPoolInfo items): |
588 | Array<BufferInfo> buffer_info_vars; |
589 | Array<BufferInfo> buffer_info_constants; |
590 | for (const auto& kv : this->buffer_info_map_) { |
591 | const auto& stmt = kv.second; |
592 | if (stmt->IsInstance<AllocateConstNode>()) { |
593 | buffer_info_constants.push_back(kv.first); |
594 | } else { |
595 | buffer_info_vars.push_back(kv.first); |
596 | } |
597 | } |
598 | ICHECK(buffer_info_map_.size() == buffer_info_vars.size() + buffer_info_constants.size()) |
599 | << "missing value" ; |
600 | |
601 | Map<ObjectRef, ObjectRef> srch; |
602 | // Then intersect constants with each other, as all constants should exist at the same time: |
603 | for (const auto& buf : buffer_info_constants) { |
604 | srch.Set(buf, buf); |
605 | Array<ObjectRef> conflicts; |
606 | std::copy_if(buffer_info_constants.begin(), buffer_info_constants.end(), |
607 | std::back_inserter(conflicts), [buf](const auto& b) { return b != buf; }); |
608 | buf->conflicts.Assign(conflicts.begin(), conflicts.end()); |
609 | } |
610 | |
611 | // And third, remove all conflicts between constants and vars: |
612 | for (const auto& buf : buffer_info_vars) { |
613 | Array<ObjectRef> conflicts; |
614 | std::copy_if(buf->conflicts.begin(), buf->conflicts.end(), std::back_inserter(conflicts), |
615 | [&srch](const auto& c) { return srch.end() == srch.find(c); }); |
616 | buf->conflicts.Assign(conflicts.begin(), conflicts.end()); |
617 | } |
618 | return BufferInfoAnalysis(this->buffer_info_map_, max_open_set_size); |
619 | } |
620 | |
621 | BufferInfoAnalysis (const PrimFunc& main_func, const IRModule& mod) { |
622 | return BufferInfoExtractor(mod)(main_func); |
623 | } |
624 | |
// Exposes the analysis to the FFI (e.g. Python) as
// "tir.usmp.analysis.extract_buffer_info".
TVM_REGISTER_GLOBAL("tir.usmp.analysis.extract_buffer_info")
    .set_body_typed([](PrimFunc main_func, IRModule mod) {
      return (ExtractBufferInfo(main_func, mod));
    });
629 | |
630 | } // namespace usmp |
631 | } // namespace tir |
632 | } // namespace tvm |
633 | |