/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tir/analysis/buffer_access_lca_detector.cc
 * \brief Detect the lowest common ancestor (LCA) of buffer access
 */

#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>

#include "../../runtime/thread_storage_scope.h"
#include "../../support/arena.h"

namespace tvm {
namespace tir {

/*!
 * \brief Detect the lowest common ancestor (LCA) position of Buffer access.
 * \note
 * - Only BlockNode and ForNode are considered as LCA candidates.
 * - The LCA detector is aware of buffer memory scopes and the CUDA thread hierarchy: any buffer
 *   in global memory has its access LCA placed outside all launch sites of `blockIdx`, so that
 *   buffer memory scopes do not conflict with the CUDA hierarchy.
 */
class LCADetector : public StmtExprVisitor {
 public:
  static Map<Buffer, Optional<Stmt>> Detect(const PrimFunc& func) {
    LCADetector detector;
    for (const auto& kv : func->buffer_map) {
      const Buffer& buffer = kv.second;
      detector.buffer_var_map_.emplace(buffer->data.get(), buffer.get());
    }

    // The root node must be explicitly present in the list of
    // ancestor_scopes_. We cannot use nullptr to represent the root
    // node, as that is also used to represent a scope that hasn't
    // been observed before.
    ScopeInfo root(nullptr, nullptr, 0);
    detector.ancestor_scopes_.push_back(&root);

    detector(func->body);
    detector.UpdateWithBlockidx();

    // Prepare the return
    Map<Buffer, Optional<Stmt>> buffer_lca;
    for (const auto& kv : detector.buffer_lca_) {
      const Buffer& buffer = GetRef<Buffer>(kv.first);
      const Optional<Stmt> stmt = kv.second ? GetRef<Optional<Stmt>>(kv.second->stmt) : NullOpt;
      buffer_lca.Set(buffer, stmt);
    }
    return buffer_lca;
  }

 private:
  /*!
   * \brief The AST node information for querying LCA.
   * \note Only BlockNode and ForNode are considered, since they are the only statements whose
   *       body can be a SeqStmt (the LCA of buffer access) in TensorIR.
   */
  struct ScopeInfo {
    // The parent scope info
    const ScopeInfo* parent_scope_info;
    // The stmt node of this scope (For or Block)
    const StmtNode* stmt;
    // The scope depth in the AST
    int depth;
    ScopeInfo(const ScopeInfo* parent_info, const StmtNode* stmt, int depth)
        : parent_scope_info(parent_info), stmt(stmt), depth(depth) {}
  };

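  // Visit a For node: open a new scope, remember loops bound to `blockIdx` (thread scope
  // rank 0), and register the loop var so that opaque block iter bindings can later be
  // traced back to their loop scopes.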
  void VisitStmt_(const ForNode* op) final {
    int n = ancestor_scopes_.size();
    const ScopeInfo* parent_scope = ancestor_scopes_.back();
    auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);

    if (op->thread_binding.defined()) {
      const runtime::ThreadScope& scope =
          runtime::ThreadScope::Create(op->thread_binding.value()->thread_tag);
      if (scope.rank == 0) {
        blockidx_scopes_.push_back(current_scope);
      }
    }

    ancestor_scopes_.push_back(current_scope);
    loop_scope_map_.insert({op->loop_var.get(), current_scope});
    StmtExprVisitor::VisitStmt_(op);
    ancestor_scopes_.pop_back();
    loop_scope_map_.erase(op->loop_var.get());
  }

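  // Visit a block realization: record the block's allocated buffers, open a new scope,
  // then handle opaque iter bindings and match_buffers before recursing into the body.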
  void VisitStmt_(const BlockRealizeNode* op) final {
    const BlockNode* block = op->block.get();
    int n = ancestor_scopes_.size();
    for (const Buffer& buf : block->alloc_buffers) {
      buffer_var_map_.emplace(buf->data.get(), buf.get());
    }

    const ScopeInfo* parent_scope = ancestor_scopes_.back();
    auto* current_scope = arena_.make<ScopeInfo>(parent_scope, block, n);

    ancestor_scopes_.push_back(current_scope);

    // For each accessed buffer of the block, update the buffer's lca to
    // the lowest inclusive stmt position, which should dominate all loops
    // related to the accessed opaque block iter vars in buffer indices.
    UpdateDominateScopeOfOpaqueIter(op);

    // Update match_buffers
    for (const MatchBufferRegion& match_buffer : block->match_buffers) {
      UpdateBufferLCA(match_buffer->source->buffer.get(), ancestor_scopes_.back());
      match_buffers_.insert(match_buffer->buffer.get());
    }

    StmtExprVisitor::VisitStmt_(op);
    ancestor_scopes_.pop_back();
  }

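  // For blocks that bind opaque iter vars (neither data-parallel nor reduction), raise the
  // access LCA of every buffer whose region uses those iter vars so that it dominates all
  // loops the opaque iter var bindings relate to.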
  void UpdateDominateScopeOfOpaqueIter(const BlockRealizeNode* block_realize) {
    // Map each opaque iter var to the scope which dominates all of its loop-carried dependencies.
    std::unordered_map<const VarNode*, const ScopeInfo*> itervar_to_dom_scope;

    // Function to collect `itervar_to_dom_scope`; the resulting scope for each block iter var
    // should be above all loop scopes the opaque iter var binding relates to.
    auto do_collect_itervar_scope = [this, &itervar_to_dom_scope](const IterVar& itervar,
                                                                  const PrimExpr& binding) {
      PostOrderVisit(binding, [this, &itervar_to_dom_scope, &itervar](const ObjectRef& obj) {
        if (const VarNode* loop_var = obj.as<VarNode>()) {
          auto it = loop_scope_map_.find(loop_var);
          if (it == loop_scope_map_.end()) {
            return;
          }
          const ScopeInfo* scope = it->second->parent_scope_info;
          // Find the highest loop scope the iter var binding relates to.
          auto dom_scope_it = itervar_to_dom_scope.find(itervar->var.get());
          if (dom_scope_it == itervar_to_dom_scope.end()) {
            itervar_to_dom_scope.insert(dom_scope_it, {itervar->var.get(), scope});
          } else if (scope->depth < dom_scope_it->second->depth) {
            dom_scope_it->second = scope;
          }
        }
      });
    };

    // Function to update the LCA scope of a buffer with loop-carried dependent accesses.
    // The resulting scope should be above all loop scopes the accessed opaque block iter vars
    // relate to, as recorded in `itervar_to_dom_scope`.
    auto do_update = [this, &itervar_to_dom_scope](const BufferRegion& region) {
      const Buffer& buffer = region->buffer;
      const ScopeInfo* scope = ancestor_scopes_.back();

      auto handle_itervar = [&itervar_to_dom_scope, &scope](const ObjectRef& obj) {
        if (const VarNode* iter_var = obj.as<VarNode>()) {
          auto dom_scope_it = itervar_to_dom_scope.find(iter_var);
          if (dom_scope_it == itervar_to_dom_scope.end()) {
            return;
          }
          // Find the highest loop scope the accessed buffer index has
          // loop-carried dependencies to (via opaque iter var bindings).
          if (dom_scope_it->second->depth < scope->depth) {
            scope = dom_scope_it->second;
          }
        }
      };

      // Visit the region's min and max to find the lowest legal LCA scope.
      for (const Range& range : region->region) {
        PostOrderVisit(range->min, handle_itervar);
        PostOrderVisit(range->min + range->extent - 1, handle_itervar);
      }
      UpdateBufferLCA(buffer.get(), scope);
    };

    // do collect and update
    const Block& block = block_realize->block;
    for (size_t i = 0; i < block_realize->iter_values.size(); ++i) {
      const IterVar& iter_var = block->iter_vars[i];
      if (iter_var->iter_type != IterVarType::kDataPar &&
          iter_var->iter_type != IterVarType::kCommReduce) {
        do_collect_itervar_scope(iter_var, block_realize->iter_values[i]);
      }
    }
    if (!itervar_to_dom_scope.empty()) {
      for (const auto& read : block->reads) {
        do_update(read);
      }
      for (const auto& write : block->writes) {
        do_update(write);
      }
    }
  }

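  // Record the enclosing scope whenever a `thread_extent` attribute launches `blockIdx`
  // (thread scope rank 0); UpdateWithBlockidx later moves global buffers' LCA outside
  // these scopes.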
  void VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent) {
      const auto* iter = op->node.as<IterVarNode>();
      ICHECK_NOTNULL(iter);
      const runtime::ThreadScope& scope = runtime::ThreadScope::Create(iter->thread_tag);
      if (scope.rank == 0) {
        blockidx_scopes_.push_back(ancestor_scopes_.back());
      }
    }
    StmtExprVisitor::VisitStmt_(op);
  }

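  // Any direct buffer load/store/realize updates the buffer's LCA with the innermost
  // enclosing scope.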
  void VisitExpr_(const BufferLoadNode* op) final {
    UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back());
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const BufferStoreNode* op) final {
    UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back());
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const BufferRealizeNode* op) final {
    buffer_var_map_.emplace(op->buffer->data.get(), op->buffer.get());
    UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back());
    StmtExprVisitor::VisitStmt_(op);
  }

  // Works for Load/Store and opaque access.
  void VisitExpr_(const VarNode* op) final { VisitBufferVar(op); }

  // Explicitly reject the deprecated Load and Store nodes.
  void VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  void VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

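  // Opaque access through a buffer's data var also updates the buffer's LCA.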
  void VisitBufferVar(const VarNode* op) {
    auto it = buffer_var_map_.find(op);
    if (it != buffer_var_map_.end()) {
      UpdateBufferLCA(it->second, ancestor_scopes_.back());
    }
  }

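  // Merge `scope` into the buffer's current LCA. Buffers bound by match_buffer are skipped,
  // since the LCA is tracked on their source buffers instead.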
  void UpdateBufferLCA(const BufferNode* buffer, const ScopeInfo* scope) {
    buffer_var_map_.emplace(buffer->data.get(), buffer);
    if (match_buffers_.find(buffer) == match_buffers_.end()) {
      // Ignore buffers created by block match_buffer
      const ScopeInfo*& lca = buffer_lca_[buffer];
      lca = LowestCommonAncestor(lca, scope);
    }
  }

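  // Relax the LCA of every global buffer so it sits outside all recorded `blockIdx` launch
  // scopes, keeping global-memory accesses out of the CUDA block hierarchy.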
  void UpdateWithBlockidx() {
    for (const auto& it : buffer_lca_) {
      const runtime::StorageScope& scope =
          runtime::StorageScope::Create(GetRef<Buffer>(it.first).scope());
      if (scope.rank == runtime::StorageRank::kGlobal) {
        const ScopeInfo*& lca = buffer_lca_[it.first];
        for (const ScopeInfo* blockidx_scope : blockidx_scopes_) {
          lca = LowestCommonAncestor(lca, blockidx_scope);
        }
      }
    }
  }

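  // Classic LCA walk over the scope chain: step the deeper scope (or both, at equal depth)
  // toward the root until the two meet or one of them reaches the root scope.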
  static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) {
    if (lhs == nullptr) return rhs;
    if (rhs == nullptr) return lhs;
    while (lhs->parent_scope_info != nullptr &&  //
           rhs->parent_scope_info != nullptr &&  //
           lhs != rhs) {
      if (lhs->depth == rhs->depth) {
        lhs = lhs->parent_scope_info;
        rhs = rhs->parent_scope_info;
      } else if (lhs->depth < rhs->depth) {
        rhs = rhs->parent_scope_info;
      } else {
        lhs = lhs->parent_scope_info;
      }
    }
    if (lhs->parent_scope_info == nullptr) {
      return lhs;
    }
    if (rhs->parent_scope_info == nullptr) {
      return rhs;
    }
    ICHECK(lhs == rhs);
    return lhs;
  }

  /*! \brief The stack of ancestor scopes (Block and For). The
   *  first element is initialized in LCADetector::Detect to represent
   *  the root scope.
   */
  std::vector<const ScopeInfo*> ancestor_scopes_ = {};
  /*! \brief The map from Buffer to its LCA ForNode/BlockNode. */
  std::unordered_map<const BufferNode*, const ScopeInfo*> buffer_lca_ = {};
  /*! \brief The map from Buffer data var to the Buffer. */
  std::unordered_map<const VarNode*, const BufferNode*> buffer_var_map_ = {};
  /*! \brief The match buffers inside blocks. */
  std::unordered_set<const BufferNode*> match_buffers_ = {};
  /*! \brief The ForNodes/BlockNodes which contain an immediate `blockIdx` launch. */
  std::vector<const ScopeInfo*> blockidx_scopes_ = {};
  /*! \brief The map from loop var to the corresponding scope. */
  std::unordered_map<const VarNode*, const ScopeInfo*> loop_scope_map_ = {};
  /*! \brief Internal arena. */
  support::Arena arena_;
};

Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func) {
  return LCADetector::Detect(func);
}

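// Registered for FFI access; on the Python side this is typically exposed as
// tvm.tir.analysis.detect_buffer_access_lca (assuming the standard Python bindings).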
TVM_REGISTER_GLOBAL("tir.analysis.detect_buffer_access_lca").set_body_typed(DetectBufferAccessLCA);
}  // namespace tir
}  // namespace tvm