/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tir/analysis/buffer_access_lca_detector.cc
 * \brief Detect the lowest common ancestor (LCA) of buffer access
 */

#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>

#include "../../runtime/thread_storage_scope.h"
#include "../../support/arena.h"

namespace tvm {
namespace tir {

/*!
 * \brief Detect the lowest common ancestor (LCA) position of Buffer access.
 * \note
 * - Only BlockNode and ForNode are considered as candidate LCA nodes.
 * - The detector is aware of buffer scopes and the CUDA thread hierarchy: any buffer in
 *   global memory has its access LCA placed outside all launch sites of `blockIdx`, so
 *   that the buffer's memory scope never conflicts with the CUDA hierarchy.
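 *
 * For example (a sketch in pseudo TIR; names and shapes are illustrative):
 *
 * \code
 *   for i in range(16):        # LCA of accesses to A
 *     A[i] = 1.0
 *     for j in range(16):      # LCA of accesses to B
 *       B[i, j] = A[i]
 * \endcode
 *
 * A is accessed both directly under loop i and inside loop j, so its LCA is the
 * `for i` loop; B is accessed only inside loop j, so its LCA is the `for j` loop.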
 */
class LCADetector : public StmtExprVisitor {
 public:
  static Map<Buffer, Optional<Stmt>> Detect(const PrimFunc& func) {
    LCADetector detector;
    for (const auto& kv : func->buffer_map) {
      const Buffer& buffer = kv.second;
      detector.buffer_var_map_.emplace(buffer->data.get(), buffer.get());
    }

    // The root node must be explicitly present in the list of
    // ancestor_scopes_. We cannot use nullptr to represent the root
    // node, as that is also used to represent a scope that hasn't
    // been observed before.
    ScopeInfo root(nullptr, nullptr, 0);
    detector.ancestor_scopes_.push_back(&root);

    detector(func->body);
    detector.UpdateWithBlockidx();

    // Prepare the return
    Map<Buffer, Optional<Stmt>> buffer_lca;
    for (const auto& kv : detector.buffer_lca_) {
      const Buffer& buffer = GetRef<Buffer>(kv.first);
      const Optional<Stmt> stmt = kv.second ? GetRef<Optional<Stmt>>(kv.second->stmt) : NullOpt;
      buffer_lca.Set(buffer, stmt);
    }
    return buffer_lca;
  }

 private:
  /*!
   * \brief The AST node information for querying LCA.
   * \note Only BlockNode and ForNode are considered, since they are the only statements whose
   *       body can be a SeqStmt (the LCA of buffer access) in TensorIR.
   */
  struct ScopeInfo {
    // The parent scope info
    const ScopeInfo* parent_scope_info;
    // The parent scope stmt node
    const StmtNode* stmt;
    // The scope depth in the AST
    int depth;
    ScopeInfo(const ScopeInfo* parent_info, const StmtNode* stmt, int depth)
        : parent_scope_info(parent_info), stmt(stmt), depth(depth) {}
  };

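  // Each For/Block opens a new scope whose depth equals the current size of the
  // ancestor stack: the root scope created in Detect() has depth 0, so the
  // outermost loop or block gets depth 1, and so on.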
  void VisitStmt_(const ForNode* op) final {
    int n = ancestor_scopes_.size();
    const ScopeInfo* parent_scope = ancestor_scopes_.back();
    auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);

    if (op->thread_binding.defined()) {
      const runtime::ThreadScope& scope =
          runtime::ThreadScope::Create(op->thread_binding.value()->thread_tag);
      if (scope.rank == 0) {
        blockidx_scopes_.push_back(current_scope);
      }
    }

    ancestor_scopes_.push_back(current_scope);
    loop_scope_map_.insert({op->loop_var.get(), current_scope});
    StmtExprVisitor::VisitStmt_(op);
    ancestor_scopes_.pop_back();
    loop_scope_map_.erase(op->loop_var.get());
  }

  void VisitStmt_(const BlockRealizeNode* op) final {
    const BlockNode* block = op->block.get();
    int n = ancestor_scopes_.size();
    for (const Buffer& buf : block->alloc_buffers) {
      buffer_var_map_.emplace(buf->data.get(), buf.get());
    }

    const ScopeInfo* parent_scope = ancestor_scopes_.back();
    auto* current_scope = arena_.make<ScopeInfo>(parent_scope, block, n);

    ancestor_scopes_.push_back(current_scope);

    // For each buffer accessed by the block, update the buffer's LCA to the lowest
    // enclosing stmt that dominates all loops bound to the opaque block iter vars
    // appearing in the buffer indices.
    UpdateDominateScopeOfOpaqueIter(op);

    // Update match_buffers
    for (const MatchBufferRegion& match_buffer : block->match_buffers) {
      UpdateBufferLCA(match_buffer->source->buffer.get(), ancestor_scopes_.back());
      match_buffers_.insert(match_buffer->buffer.get());
    }

    StmtExprVisitor::VisitStmt_(op);
    ancestor_scopes_.pop_back();
  }

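  /*!
   * \brief Hoist the LCA of buffers that are indexed through opaque block iter vars.
   * \note A sketch of the case this handles (pseudo TIR; names are illustrative):
   * \code
   *   for j in range(16):
   *     block:
   *       vj = T.axis.opaque(16, j)  # opaque binding to loop var j
   *       A[vj] = ...
   * \endcode
   * Since vj is opaque, the access pattern of A across loop j is unknown, so A's LCA
   * must sit above loop j, i.e. at the parent scope of the loop its binding refers to.
   */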
  void UpdateDominateScopeOfOpaqueIter(const BlockRealizeNode* block_realize) {
    // Map each opaque iter var to the scope that dominates all of its loop-carried
    // dependencies.
    std::unordered_map<const VarNode*, const ScopeInfo*> itervar_to_dom_scope;

    // Function to collect `itervar_to_dom_scope`: the resulting scope for each opaque
    // block iter var must be above all loop scopes its binding refers to.
    auto do_collect_itervar_scope = [this, &itervar_to_dom_scope](const IterVar& itervar,
                                                                  const PrimExpr& binding) {
      PostOrderVisit(binding, [this, &itervar_to_dom_scope, &itervar](const ObjectRef& obj) {
        if (const VarNode* loop_var = obj.as<VarNode>()) {
          auto it = loop_scope_map_.find(loop_var);
          if (it == loop_scope_map_.end()) {
            return;
          }
          const ScopeInfo* scope = it->second->parent_scope_info;
          // Find the highest loop scope the iter var binding refers to.
          auto dom_scope_it = itervar_to_dom_scope.find(itervar->var.get());
          if (dom_scope_it == itervar_to_dom_scope.end()) {
            itervar_to_dom_scope.insert(dom_scope_it, {itervar->var.get(), scope});
          } else if (scope->depth < dom_scope_it->second->depth) {
            dom_scope_it->second = scope;
          }
        }
      });
    };

    // Function to update the LCA scope of a buffer whose accesses carry loop
    // dependencies. The resulting scope must be above all loop scopes that the
    // accessed opaque block iter vars refer to, as recorded in `itervar_to_dom_scope`.
    auto do_update = [this, &itervar_to_dom_scope](const BufferRegion& region) {
      const Buffer& buffer = region->buffer;
      const ScopeInfo* scope = ancestor_scopes_.back();

      auto handle_itervar = [&itervar_to_dom_scope, &scope](const ObjectRef& obj) {
        if (const VarNode* iter_var = obj.as<VarNode>()) {
          auto dom_scope_it = itervar_to_dom_scope.find(iter_var);
          if (dom_scope_it == itervar_to_dom_scope.end()) {
            return;
          }
          // Find the highest loop scope on which the accessed buffer index has a
          // loop-carried dependency (via the opaque iter var binding).
          if (dom_scope_it->second->depth < scope->depth) {
            scope = dom_scope_it->second;
          }
        }
      };

      // Visit the region min and max to find the lowest legal LCA scope.
      for (const Range& range : region->region) {
        PostOrderVisit(range->min, handle_itervar);
        PostOrderVisit(range->min + range->extent - 1, handle_itervar);
      }
      UpdateBufferLCA(buffer.get(), scope);
    };

    // Collect the dominating scopes, then update the buffer LCAs.
    const Block& block = block_realize->block;
    for (size_t i = 0; i < block_realize->iter_values.size(); ++i) {
      const IterVar& iter_var = block->iter_vars[i];
      if (iter_var->iter_type != IterVarType::kDataPar &&
          iter_var->iter_type != IterVarType::kCommReduce) {
        do_collect_itervar_scope(iter_var, block_realize->iter_values[i]);
      }
    }
    if (!itervar_to_dom_scope.empty()) {
      for (const auto& read : block->reads) {
        do_update(read);
      }
      for (const auto& write : block->writes) {
        do_update(write);
      }
    }
  }

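  // Also record blockIdx launch sites expressed through the pre-TensorIR
  // AttrStmt `thread_extent` form, in addition to For thread_binding above.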
  void VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent) {
      const auto* iter = op->node.as<IterVarNode>();
      ICHECK_NOTNULL(iter);
      const runtime::ThreadScope& scope = runtime::ThreadScope::Create(iter->thread_tag);
      if (scope.rank == 0) {
        blockidx_scopes_.push_back(ancestor_scopes_.back());
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back());
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const BufferStoreNode* op) final {
    UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back());
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const BufferRealizeNode* op) final {
    buffer_var_map_.emplace(op->buffer->data.get(), op->buffer.get());
    UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back());
    StmtExprVisitor::VisitStmt_(op);
  }

  // Detect opaque accesses to a buffer through its data var.
  void VisitExpr_(const VarNode* op) final { VisitBufferVar(op); }

  // The deprecated Load/Store nodes are rejected explicitly; buffer accesses must be
  // expressed as BufferLoad/BufferStore instead.
  void VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  void VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  void VisitBufferVar(const VarNode* op) {
    auto it = buffer_var_map_.find(op);
    if (it != buffer_var_map_.end()) {
      UpdateBufferLCA(it->second, ancestor_scopes_.back());
    }
  }

  void UpdateBufferLCA(const BufferNode* buffer, const ScopeInfo* scope) {
    buffer_var_map_.emplace(buffer->data.get(), buffer);
    if (match_buffers_.find(buffer) == match_buffers_.end()) {
      // Ignore buffers created by block match_buffer.
      const ScopeInfo*& lca = buffer_lca_[buffer];
      lca = LowestCommonAncestor(lca, scope);
    }
  }

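  // Hoist the LCA of every buffer in global memory above all recorded `blockIdx`
  // scopes. For example, if a global buffer is written under one `blockIdx` launch
  // and read under another, its LCA must sit outside both launch sites, because
  // global memory is shared across thread blocks.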
  void UpdateWithBlockidx() {
    for (const auto& it : buffer_lca_) {
      const runtime::StorageScope& scope =
          runtime::StorageScope::Create(GetRef<Buffer>(it.first).scope());
      if (scope.rank == runtime::StorageRank::kGlobal) {
        const ScopeInfo*& lca = buffer_lca_[it.first];
        for (const ScopeInfo* blockidx_scope : blockidx_scopes_) {
          lca = LowestCommonAncestor(lca, blockidx_scope);
        }
      }
    }
  }

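  // Classic two-pointer LCA walk: lift the deeper scope until both sides are at the
  // same depth, then move both up in lockstep until they meet. For example, with
  // `lhs` at depth 5 and `rhs` at depth 3, `lhs` is lifted twice before the joint walk.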
  static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) {
    if (lhs == nullptr) return rhs;
    if (rhs == nullptr) return lhs;
    while (lhs->parent_scope_info != nullptr &&  //
           rhs->parent_scope_info != nullptr &&  //
           lhs != rhs) {
      if (lhs->depth == rhs->depth) {
        lhs = lhs->parent_scope_info;
        rhs = rhs->parent_scope_info;
      } else if (lhs->depth < rhs->depth) {
        rhs = rhs->parent_scope_info;
      } else {
        lhs = lhs->parent_scope_info;
      }
    }
    if (lhs->parent_scope_info == nullptr) {
      return lhs;
    }
    if (rhs->parent_scope_info == nullptr) {
      return rhs;
    }
    ICHECK(lhs == rhs);
    return lhs;
  }

  /*! \brief The stack of ancestor scopes (Block and For). The first element is
   *  initialized in LCADetector::Detect to represent the root scope.
   */
  std::vector<const ScopeInfo*> ancestor_scopes_ = {};
  /*! \brief The map from Buffer to its LCA ForNode/BlockNode. */
  std::unordered_map<const BufferNode*, const ScopeInfo*> buffer_lca_ = {};
  /*! \brief The map from Buffer data var to the Buffer. */
  std::unordered_map<const VarNode*, const BufferNode*> buffer_var_map_ = {};
  /*! \brief The match buffers inside blocks. */
  std::unordered_set<const BufferNode*> match_buffers_ = {};
  /*! \brief The ForNodes/BlockNodes that contain an immediate `blockIdx` launch. */
  std::vector<const ScopeInfo*> blockidx_scopes_ = {};
  /*! \brief The map from loop var to the corresponding scope. */
  std::unordered_map<const VarNode*, const ScopeInfo*> loop_scope_map_ = {};
  /*! \brief Internal arena. */
  support::Arena arena_;
};

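/*!
 * \brief Entry point of the analysis. A buffer mapped to NullOpt in the result means
 * its LCA is the root of the PrimFunc, i.e. the buffer is accessed across the whole
 * function body.
 */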
Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func) {
  return LCADetector::Detect(func);
}

TVM_REGISTER_GLOBAL("tir.analysis.detect_buffer_access_lca").set_body_typed(DetectBufferAccessLCA);

}  // namespace tir
}  // namespace tvm