/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tir/analysis/usmp/extract_buffer_info.cc
 *
 * \brief This analysis pass consumes a TIR IRModule with a main function
 * that defines an ordering of calls to operators and produces BufferInfo
 * objects that contain information about tir.allocate nodes and their
 * liveness conflicts with other tir.allocate nodes.
 */
#include <tvm/arith/analyzer.h>
#include <tvm/relay/executor.h>
#include <tvm/runtime/device_api.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/usmp/analysis.h>
#include <tvm/tir/usmp/utils.h>

#include <stack>

#include "../../../runtime/thread_storage_scope.h"

namespace tvm {
namespace tir {
namespace usmp {

/*!
 * \brief The visitor class to obtain buffer information
 *
 * The visitor initiates the traversal from the main
 * function and visits the operator PrimFuncs it calls. It
 * creates a unique BufferInfo object for each Allocate node.
 *
 * Every time the buffer variable of an allocate node is referenced,
 * the access is recorded using the current stmt index. Note that
 * the same buffer variable could be referenced multiple times
 * from different calls. Thereafter, a sweep is done over all the
 * BufferInfo objects using the per-call liveness events. In the sweep,
 * BufferInfo objects that are live together are recorded as
 * mutual conflicts of each other.
 */
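// For illustration only (a simplified, hypothetical sketch rather than real TIR),
// a main function of the following shape drives the analysis:
//
//   __tvm_main__(input, output) {
//     let sid_1 = allocate(...)            // becomes BufferInfo "sid_1"
//     call_extern("op_a", input, sid_1)    // first use of sid_1 recorded here
//     call_extern("op_b", sid_1, output)   // last use of sid_1 recorded here
//   }
//
// "sid_1" is live across both calls, so it would be marked as conflicting with
// any buffer allocated inside op_a or op_b that is live at the same ticks.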
class BufferInfoExtractor : public StmtExprVisitor {
 public:
  explicit BufferInfoExtractor(const IRModule& module) : module_(module) {
    for (const auto& gv_func : module_->functions) {
      if (gv_func.second->IsInstance<PrimFuncNode>()) {
        functions_.Set(gv_func.first->name_hint, Downcast<PrimFunc>(gv_func.second));
      }
    }
    // Push a scope info for the initial body of the main function
    scope_stack_.push(ScopeInfo());
  }
  BufferInfoAnalysis operator()(const PrimFunc& func);

 private:
  void VisitStmt(const Stmt& n) override;
  void VisitStmt_(const AllocateNode* op) override;
  void VisitStmt_(const AllocateConstNode* op) override;
  void VisitExpr_(const CallNode* op) override;
  void VisitExpr_(const VarNode* op) override;
  void VisitExpr_(const BufferLoadNode* op) override;
  void VisitStmt_(const BufferStoreNode* op) override;
  void VisitStmt_(const ForNode* op) override;

  void UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func);
  void RecordAllocateNodeInfo(const AllocateNode* op);
  void RecordAllocateConstNodeInfo(const AllocateConstNode* op);
  void VisitPrimFunc(const PrimFunc& func, const Call& call);

  /*!
   * \brief Maintains the mapping of BufferInfo to their associated TIR Statements.
   */
  Map<BufferInfo, tir::Stmt> buffer_info_map_;
  /*!
   * \brief Records the order of calls in the main function for stability.
   */
  std::vector<Call> call_order_;
  /*!
   * \brief Lookup to avoid adding duplicates to `call_order_`.
   */
  std::unordered_set<Call, ObjectPtrHash, ObjectPtrEqual> call_order_contents_;
  /*!
   * \brief Records the first access, in terms of Stmts, to each buffer per call
   *
   * This is needed because multiple calls could happen to the same PrimFunc.
   */
  std::unordered_map<Call, Map<tir::Stmt, Integer>, ObjectPtrHash, ObjectPtrEqual>
      buffer_info_start_stmt_idx_;
  /*!
   * \brief Records the last access, in terms of Stmts, to each buffer per call
   *
   * This is needed because multiple calls could happen to the same PrimFunc.
   */
  std::unordered_map<Call, Map<tir::Stmt, Integer>, ObjectPtrHash, ObjectPtrEqual>
      buffer_info_end_stmt_idx_;

  /*!
   * \brief This structure contains information regarding an Allocate node.
   */
  struct AllocateInfo {
    tir::Stmt Allocate;
    PrimFunc prim_func;
    Call call;
  };

  /*!
   * \brief Maintains the mapping of buffer variables to their allocate nodes to ensure
   * that only one BufferInfo object is created per allocation.
   */
  std::unordered_map<tir::Var, AllocateInfo, ObjectPtrHash, ObjectPtrEqual> allocate_infos;
  /*!
   * \brief A count of stmts visited so far, used as the metric of liveness
   */
  int current_stmt_idx_ = 0;
  /*!
   * \brief This structure contains information about the scope
   * the visitor is currently in.
   */
  struct ScopeInfo {
    /*!
     * \brief We need to record access per call
     */
    Call call;
    /*!
     * \brief Having access to PrimFunc metadata is useful
     */
    PrimFunc func;
    /*!
     * \brief We currently support only serial for loops. Therefore,
     * we need to know what kind of for loop the visitor is in.
     */
    For for_loop;
    /*!
     * \brief We record the live allocate nodes because, once inside loops,
     * the liveness range has to be extended to the whole of the nested
     * loop structure.
     */
    std::unordered_set<Allocate, ObjectPtrHash, ObjectPtrEqual> allocate_nodes;
    /*!
     * \brief We record the live allocate_const nodes because, once inside loops,
     * the liveness range has to be extended to the whole of the nested
     * loop structure.
     */
    std::unordered_set<AllocateConst, ObjectPtrHash, ObjectPtrEqual> allocate_const_nodes;
    /*!
     * \brief This is recorded to extend the liveness of all allocates within
     * the nested loop structure.
     */
    Integer initial_stmt_of_the_nested_loops;
  };
  std::stack<ScopeInfo> scope_stack_;

  /*!
   * \brief A liveness event marks the point, while traversing the tir.Stmts,
   * where a tir.allocate node begins or ceases to be live. This particular
   * struct is used to solve the interval overlap problem using a sweep-line
   * algorithm. For that, we need to record where each liveness event
   * occurred, in chronological order.
   */
  enum LivenessEventType { START = 0, END = 1 };
  struct LivenessEvent {
    size_t tick;
    LivenessEventType le_type;
    BufferInfo buffer_info;
    bool operator==(const LivenessEvent& other) {
      if (tick == other.tick && le_type == other.le_type && buffer_info == other.buffer_info) {
        return true;
      }
      return false;
    }
  };
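  // For example (illustrative ticks only): buffers A live over [3, 10] and B over
  // [7, 12] yield the sorted event stream START(A)@3, START(B)@7, END(A)@10,
  // END(B)@12. B starts while A is still in the open set, so A and B are recorded
  // as mutual conflicts during the sweep in operator().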
  /*!
   * \brief We need to create a unique buffer name if the same name is used in
   * two allocate nodes, for clarity for memory planning algorithms.
   */
  std::string GetUniqueBufferName(std::string name);

  /*!
   * \brief A per-buffer-name counter to aid generating the above
   * unique names.
   */
  std::unordered_map<std::string, int> buffer_names;
  /*!
   * \brief The TIR main function calls PrimFuncs by name to be able to
   * support BYOC. Therefore, this Map records the functions that are present
   * in the IRModule by name.
   */
  Map<String, PrimFunc> functions_;
  /*!
   * \brief The IRModule being analyzed.
   */
  IRModule module_;
};

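// When a buffer name has already been seen, a per-name counter is appended, e.g.
// a second allocate named "sid_8" (a hypothetical name) would be renamed to
// "sid_82", since the counter is at 2 by then.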
std::string BufferInfoExtractor::GetUniqueBufferName(std::string name) {
  if (buffer_names.find(name) == buffer_names.end()) {
    buffer_names[name] = 1;
    return name;
  } else {
    buffer_names[name] = buffer_names[name] + 1;
    return name + std::to_string(buffer_names[name]);
  }
}

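// Every visited Stmt advances current_stmt_idx_; this counter is the "tick" used
// to record the liveness start/end positions of each buffer.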
void BufferInfoExtractor::VisitStmt(const Stmt& n) {
  current_stmt_idx_ += 1;
  StmtExprVisitor::VisitStmt(n);
}

void BufferInfoExtractor::RecordAllocateNodeInfo(const AllocateNode* op) {
  auto size_bytes = CalculateExtentsSize(op);
  // We only statically memory plan allocates with known
  // compile-time sizes.
  if (size_bytes.defined()) {
    if (allocate_infos.find(op->buffer_var) == allocate_infos.end()) {
      // By default, the core compiler is assumed to attach a default pool to each allocate.
      ICHECK(op->annotations.count(kPoolCandidatesAllocateAttr))
          << "Every statically sized allocate node needs a pool candidate attribute";
      auto pool_candidates =
          Downcast<Array<PoolInfo>>(op->annotations[kPoolCandidatesAllocateAttr]);

      ICHECK(pool_candidates.size() > 0)
          << "The AssignPoolInfo pass should attach at least a single PoolInfo. If there were no "
             "user-given arguments for memory pools, the default behaviour is that a single "
             "size-unrestricted pool is assigned";
      PrimFunc func = scope_stack_.top().func;
      Optional<tvm::relay::Executor> executor_config =
          module_->GetAttr<tvm::relay::Executor>(tvm::attr::kExecutor);
      Integer workspace_alignment = 16;
      if (executor_config) {
        workspace_alignment =
            executor_config.value()->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
      }

      BufferInfoKind bi_kind = BufferInfoKind::kIntermediate;
      String buffer_info_name = op->buffer_var->name_hint;
      if (op->annotations.find(kInputTensorAllocate) != op->annotations.end()) {
        bi_kind = BufferInfoKind::kInput;
        // Use the original input name instead of the buffer_var name
        // because this name will be used in the lowering to convey
        // the pool allocation.
        buffer_info_name = Downcast<String>(op->annotations[kInputTensorAllocate]);
      } else if (op->annotations.find(kOutputTensorAllocate) != op->annotations.end()) {
        bi_kind = BufferInfoKind::kOutput;
        // Use the original output name instead of the buffer_var name
        // because this name will be used in the lowering to convey
        // the pool allocation.
        buffer_info_name = Downcast<String>(op->annotations[kOutputTensorAllocate]);
      }
      auto buffer_info = BufferInfo(GetUniqueBufferName(buffer_info_name), size_bytes,
                                    pool_candidates, workspace_alignment, bi_kind);
      auto allocate = GetRef<Allocate>(op);
      allocate_infos[op->buffer_var] =
          AllocateInfo{allocate, scope_stack_.top().func, scope_stack_.top().call};
      buffer_info_map_.Set(buffer_info, allocate);
    } else {
      // Update the allocate info with the latest call
      AllocateInfo ai = allocate_infos[op->buffer_var];
      ai.call = scope_stack_.top().call;
      allocate_infos[op->buffer_var] = ai;
    }
  }
}

void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) {
  ScopeInfo& current_scope_info = scope_stack_.top();
  const auto& type = Downcast<PointerType>(op->buffer_var->type_annotation);
  const auto& storage_scope = runtime::StorageScope::Create(type->storage_scope);

  // If the allocate is inside a for loop, the USMP currently only considers serial for loops.
  // If it is not a serial for loop, the memory planner will omit the allocate from the current
  // memory planning process, leaving it as a tir.allocate node for codegen. Additionally, the
  // USMP can only work with buffers that have global storage_scope.

  if (storage_scope.rank == runtime::StorageRank::kGlobal) {
    if (!current_scope_info.for_loop.defined()) {
      RecordAllocateNodeInfo(op);
    } else if (current_scope_info.for_loop.defined() &&
               current_scope_info.for_loop->kind == ForKind::kSerial) {
      RecordAllocateNodeInfo(op);
    }
  }
  StmtExprVisitor::VisitStmt(op->body);
  current_scope_info.allocate_nodes.erase(GetRef<Allocate>(op));
}

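// Constant allocations are handled analogously to Allocate nodes;
// RecordAllocateConstNodeInfo below simply skips constants that carry no pool
// candidate annotation.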
void BufferInfoExtractor::VisitStmt_(const AllocateConstNode* op) {
  ScopeInfo& current_scope_info = scope_stack_.top();
  RecordAllocateConstNodeInfo(op);
  StmtExprVisitor::VisitStmt(op->body);
  current_scope_info.allocate_const_nodes.erase(GetRef<AllocateConst>(op));
}

void BufferInfoExtractor::RecordAllocateConstNodeInfo(const AllocateConstNode* op) {
  if (!op->annotations.count(kPoolCandidatesAllocateAttr)) {
    return;
  }
  Integer size_bytes = CalculateExtentsSize(op);
  ICHECK(size_bytes.defined()) << "constant node size should be defined";
  const auto& buffer_var = op->buffer_var;
  if (allocate_infos.find(buffer_var) == allocate_infos.end()) {
    // By default, the core compiler is assumed to attach a default pool to each allocate.
    ICHECK(op->annotations.count(kPoolCandidatesAllocateAttr))
        << "Every statically sized allocate node needs a pool candidate attribute";
    auto pool_candidates = Downcast<Array<PoolInfo>>(op->annotations[kPoolCandidatesAllocateAttr]);
    ICHECK(pool_candidates.size() > 0)
        << "The core compiler should attach at least a single PoolInfo. If there were no "
           "user-given arguments for memory pools, the default behaviour is that a single "
           "size-unrestricted pool is assigned";
    PrimFunc func = scope_stack_.top().func;
    Optional<tvm::relay::Executor> executor_config =
        module_->GetAttr<tvm::relay::Executor>(tvm::attr::kExecutor);
    Integer alignment = 16;
    if (executor_config) {
      alignment =
          executor_config.value()->GetAttr<Integer>("constant-byte-alignment").value_or(alignment);
    }
    auto buffer_info = BufferInfo(GetUniqueBufferName(buffer_var->name_hint), size_bytes,
                                  pool_candidates, alignment);
    auto allocate = GetRef<AllocateConst>(op);
    allocate_infos[buffer_var] =
        AllocateInfo{allocate, scope_stack_.top().func, scope_stack_.top().call};
    buffer_info_map_.Set(buffer_info, allocate);
  } else {
    // Update the allocate info with the latest call
    AllocateInfo ai = allocate_infos[buffer_var];
    ai.call = scope_stack_.top().call;
    allocate_infos[buffer_var] = ai;
  }
}

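// A fresh scope is pushed for every for loop. Buffers touched inside the loop get
// their liveness extended to cover the whole nested loop structure, since the
// loop body executes repeatedly.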
void BufferInfoExtractor::VisitStmt_(const ForNode* op) {
  ScopeInfo si{scope_stack_.top().call,
               scope_stack_.top().func,
               GetRef<For>(op),
               scope_stack_.top().allocate_nodes,
               scope_stack_.top().allocate_const_nodes,
               scope_stack_.top().initial_stmt_of_the_nested_loops};
  if (!scope_stack_.top().initial_stmt_of_the_nested_loops.defined()) {
    si.initial_stmt_of_the_nested_loops = Integer(current_stmt_idx_);
  }
  Call current_call = scope_stack_.top().call;
  PrimFunc current_primfunc = scope_stack_.top().func;
  scope_stack_.push(si);
  StmtExprVisitor::VisitStmt_(op);
  // Extend the liveness of buffers used inside the loop so that it spans from the
  // first stmt of the nested loop structure to the end of the current for loop.
  for (const Allocate& allocate : scope_stack_.top().allocate_nodes) {
    AllocateInfo ai = allocate_infos[allocate->buffer_var];
    Call update_call = current_call;
    // If the allocate does not belong to the current PrimFunc,
    // we need to update the call to which the allocate belongs.
    if (ai.prim_func != current_primfunc) {
      update_call = ai.call;
    }
    if (scope_stack_.top().initial_stmt_of_the_nested_loops->value <
        buffer_info_start_stmt_idx_[update_call][allocate].IntValue()) {
      buffer_info_start_stmt_idx_[update_call].Set(
          allocate, scope_stack_.top().initial_stmt_of_the_nested_loops->value);
    }
    if (current_stmt_idx_ > buffer_info_end_stmt_idx_[update_call][allocate].IntValue()) {
      buffer_info_end_stmt_idx_[update_call].Set(allocate, current_stmt_idx_);
    }
  }
  scope_stack_.pop();
}

void BufferInfoExtractor::VisitExpr_(const BufferLoadNode* op) {
  this->VisitExpr(op->buffer->data);
  StmtExprVisitor::VisitExpr_(op);
}

void BufferInfoExtractor::VisitStmt_(const BufferStoreNode* op) {
  this->VisitExpr(op->buffer->data);
  StmtExprVisitor::VisitStmt_(op);
}

void BufferInfoExtractor::VisitExpr_(const VarNode* op) {
  auto var = GetRef<Var>(op);
  Call current_call = scope_stack_.top().call;
  PrimFunc current_primfunc = scope_stack_.top().func;
  if (allocate_infos.count(var)) {
    auto allocate = allocate_infos[var].Allocate;
    auto allocate_primfunc = allocate_infos[var].prim_func;
    Call update_call = current_call;
    if (allocate_primfunc != current_primfunc) {
      // If the allocate node does not belong to the current PrimFunc,
      // its access should be reported against the call to the PrimFunc
      // that the Allocate belongs to.
      update_call = allocate_infos[var].call;
    }
    if (buffer_info_start_stmt_idx_[update_call].count(allocate) == 0) {
      buffer_info_start_stmt_idx_[update_call].Set(allocate, current_stmt_idx_);
    }
    buffer_info_end_stmt_idx_[update_call].Set(allocate, current_stmt_idx_);

    ScopeInfo& current_scope_info = scope_stack_.top();
    if (current_scope_info.for_loop.defined()) {
      if (allocate->IsInstance<AllocateNode>()) {
        current_scope_info.allocate_nodes.insert(Downcast<Allocate>(allocate));
      } else if (allocate->IsInstance<AllocateConstNode>()) {
        current_scope_info.allocate_const_nodes.insert(Downcast<AllocateConst>(allocate));
      } else {
        LOG(FATAL) << "Handling of " << allocate->GetTypeKey() << " is not implemented";
      }
    }
  }
  StmtExprVisitor::VisitExpr_(op);
}

static Array<Var> GetMatchedBuffers(const PrimFunc& func) {
  Array<Var> buffer_vars;
  if (func->params.size() > 0) {
    for (unsigned int i = 0; i < func->params.size() - 1; i++) {
      Var param = func->params[i];
      buffer_vars.push_back(func->buffer_map[param]->data);
    }
    Var last_param = func->params.back();
    // Checks whether last var is present in the buffer map
    // because it could be the resource handle
    if (func->buffer_map.find(last_param) != func->buffer_map.end()) {
      buffer_vars.push_back(func->buffer_map[last_param]->data);
    }
  }
  return buffer_vars;
}

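// Aliasing sketch (with hypothetical names): if the main function contains
//   call_extern("op_a", input, sid_1)
// and the matching parameter buffer var inside op_a is "placeholder_1", then
// allocate_infos["placeholder_1"] is pointed at the AllocateInfo of "sid_1", so
// accesses made inside op_a update the liveness of the caller's allocation.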
void BufferInfoExtractor::UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func) {
  auto param_buffers = GetMatchedBuffers(func);
  // The last var could be a resource handle that does not have a Buffer
  ICHECK(args.size() == param_buffers.size() || args.size() - 1 == param_buffers.size());
  for (size_t i = 0; i < param_buffers.size(); i++) {
    auto arg = args[i];
    auto param_buf = param_buffers[i];
    // If tir.allocates are passed in to functions,
    // the function params are re-directed to point
    // to the original allocate
    if (arg->IsInstance<LoadNode>()) {
      auto load = Downcast<Load>(arg);
      if (allocate_infos.count(load->buffer_var)) {
        allocate_infos[param_buf] = allocate_infos[load->buffer_var];
      }
    } else if (arg->IsInstance<VarNode>()) {
      auto var = Downcast<Var>(arg);
      if (allocate_infos.count(var)) {
        allocate_infos[param_buf] = allocate_infos[var];
      }
    }
  }
}

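// Each call gets its own ScopeInfo so that accesses are recorded per call, and
// each unique Call is remembered once in call_order_ to keep the later liveness
// sweep deterministic.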
void BufferInfoExtractor::VisitPrimFunc(const PrimFunc& func, const Call& call) {
  ScopeInfo si{call,
               func,
               scope_stack_.top().for_loop,
               scope_stack_.top().allocate_nodes,
               scope_stack_.top().allocate_const_nodes,
               scope_stack_.top().initial_stmt_of_the_nested_loops};
  if (call_order_contents_.count(call) == 0) {
    call_order_contents_.insert(call);
    call_order_.push_back(call);
  }
  scope_stack_.push(si);
  this->VisitStmt(func->body);
  scope_stack_.pop();
}

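// Two calling conventions are recognized: call_extern / tvm_call_cpacked calls
// that name the callee via a StringImm first argument, and calls whose op is a
// PrimFunc directly. Anything else falls through to the default visitor.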
void BufferInfoExtractor::VisitExpr_(const CallNode* op) {
  if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) {
    StringImm func_name = Downcast<StringImm>(op->args[0])->value;
    if (functions_.find(func_name->value) != functions_.end()) {
      auto func = functions_.at(func_name->value);
      auto actual_args = Array<PrimExpr>(op->args.begin() + 1, op->args.end());
      this->UpdateAliases(actual_args, func);
      VisitPrimFunc(func, GetRef<Call>(op));
      return;
    }
  }
  if (op->op->IsInstance<PrimFuncNode>()) {
    auto func = Downcast<PrimFunc>(op->op);
    this->UpdateAliases(op->args, func);
    VisitPrimFunc(func, GetRef<Call>(op));
    return;
  }
  StmtExprVisitor::VisitExpr_(op);
}

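// Entry point of the analysis, in three phases:
//   1) visit the main PrimFunc (and the PrimFuncs it calls) to collect BufferInfo
//      objects along with per-call first/last access ticks,
//   2) turn those ticks into a timeline of START/END liveness events,
//   3) sweep the sorted timeline, recording conflicts between buffers that are
//      live simultaneously and tracking the peak sum of live buffer sizes, which
//      is reported as the memory pressure of the analysis.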
BufferInfoAnalysis BufferInfoExtractor::operator()(const PrimFunc& main_func) {
  VisitPrimFunc(main_func, Call());

  // Create a vector of liveness events
  // associated with each BufferInfo object.
  std::vector<LivenessEvent> le_events_timeline;
  for (const auto& kv1 : buffer_info_map_) {
    if (!kv1.second->IsInstance<AllocateNode>() && !kv1.second->IsInstance<AllocateConstNode>()) {
      continue;
    }

    auto allocate = Downcast<Stmt>(kv1.second);
    auto buffer_info = Downcast<BufferInfo>(kv1.first);

    ICHECK(call_order_.size() >= buffer_info_start_stmt_idx_.size());
    ICHECK(call_order_.size() >= buffer_info_end_stmt_idx_.size());

    for (const Call& call : call_order_) {
      Map<Stmt, Integer> buffer_info_starts = buffer_info_start_stmt_idx_[call];
      if (buffer_info_starts.find(allocate) != buffer_info_starts.end()) {
        LivenessEvent le_event_start;
        le_event_start.buffer_info = buffer_info;
        le_event_start.le_type = START;
        le_event_start.tick = buffer_info_starts[allocate].IntValue();
        le_events_timeline.push_back(le_event_start);
      }
    }

    for (const Call& call : call_order_) {
      Map<Stmt, Integer> buffer_info_ends = buffer_info_end_stmt_idx_[call];
      if (buffer_info_ends.find(allocate) != buffer_info_ends.end()) {
        LivenessEvent le_event_end;
        le_event_end.buffer_info = buffer_info;
        le_event_end.le_type = END;
        le_event_end.tick = buffer_info_ends[allocate].IntValue();
        le_events_timeline.push_back(le_event_end);
      }
    }
  }

  // Sort the liveness events based on the chronological
  // ordering. For events that are simultaneous, the START event
  // takes precedence.
  std::sort(le_events_timeline.begin(), le_events_timeline.end(),
            [](const LivenessEvent& lhs, const LivenessEvent& rhs) {
              if (lhs.tick < rhs.tick) {
                return true;
              } else if (lhs.tick == rhs.tick && lhs.le_type == START && rhs.le_type == END) {
                return true;
              }
              return false;
            });

  // Traverse the liveness events using an open set to track what
  // is live while updating the conflicts throughout the linear traversal.

  int open_set_size = 0;
  int max_open_set_size = 0;
  std::unordered_set<BufferInfo, ObjectPtrHash, ObjectPtrEqual> open_set;
  for (const auto& le_event : le_events_timeline) {
    if (le_event.le_type == START) {
      for (const BufferInfo& open_buffer_info : open_set) {
        open_buffer_info->conflicts.push_back(le_event.buffer_info);
        if (le_event.buffer_info != open_buffer_info) {
          le_event.buffer_info->conflicts.push_back(open_buffer_info);
        }
      }
      open_set_size += le_event.buffer_info->size_bytes.IntValue();
      if (open_set_size > max_open_set_size) {
        max_open_set_size = open_set_size;
      }
      open_set.insert(le_event.buffer_info);
    } else {
      open_set_size -= le_event.buffer_info->size_bytes.IntValue();
      open_set.erase(le_event.buffer_info);
    }
  }

  // All ConstantPoolInfo items should have conflicts with each other,
  // as they will be placed in the RO segment and pre-initialized. To achieve this,
  // first split the buffers into vars (WorkspacePoolInfo items) and constants
  // (ConstantPoolInfo items):
  Array<BufferInfo> buffer_info_vars;
  Array<BufferInfo> buffer_info_constants;
  for (const auto& kv : this->buffer_info_map_) {
    const auto& stmt = kv.second;
    if (stmt->IsInstance<AllocateConstNode>()) {
      buffer_info_constants.push_back(kv.first);
    } else {
      buffer_info_vars.push_back(kv.first);
    }
  }
  ICHECK(buffer_info_map_.size() == buffer_info_vars.size() + buffer_info_constants.size())
      << "missing value";

  Map<ObjectRef, ObjectRef> srch;
  // Then intersect constants with each other, as all constants should exist at the same time:
  for (const auto& buf : buffer_info_constants) {
    srch.Set(buf, buf);
    Array<ObjectRef> conflicts;
    std::copy_if(buffer_info_constants.begin(), buffer_info_constants.end(),
                 std::back_inserter(conflicts), [buf](const auto& b) { return b != buf; });
    buf->conflicts.Assign(conflicts.begin(), conflicts.end());
  }

  // And third, remove all conflicts between constants and vars:
  for (const auto& buf : buffer_info_vars) {
    Array<ObjectRef> conflicts;
    std::copy_if(buf->conflicts.begin(), buf->conflicts.end(), std::back_inserter(conflicts),
                 [&srch](const auto& c) { return srch.end() == srch.find(c); });
    buf->conflicts.Assign(conflicts.begin(), conflicts.end());
  }
  return BufferInfoAnalysis(this->buffer_info_map_, max_open_set_size);
}

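// Illustrative usage from C++ (a sketch; assumes `mod` is a lowered IRModule and
// `main_func` is its main PrimFunc):
//   BufferInfoAnalysis analysis = ExtractBufferInfo(main_func, mod);
// The same functionality is exposed to the rest of the stack through the
// "tir.usmp.analysis.extract_buffer_info" packed function registered below.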
BufferInfoAnalysis ExtractBufferInfo(const PrimFunc& main_func, const IRModule& mod) {
  return BufferInfoExtractor(mod)(main_func);
}

TVM_REGISTER_GLOBAL("tir.usmp.analysis.extract_buffer_info")
    .set_body_typed([](PrimFunc main_func, IRModule mod) {
      return (ExtractBufferInfo(main_func, mod));
    });

}  // namespace usmp
}  // namespace tir
}  // namespace tvm