1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tir/analysis/block_region_detector.cc |
22 | * \brief Detect block read/write regions by visiting its body |
23 | */ |
24 | |
25 | #include <tvm/arith/analyzer.h> |
26 | #include <tvm/tir/op.h> |
27 | #include <tvm/tir/stmt_functor.h> |
28 | |
29 | #include "../transforms/ir_utils.h" |
30 | namespace tvm { |
31 | namespace tir { |
32 | |
/*!
 * \brief Detect which regions of tensors in this block are read or written to. Regions are sorted
 * by order of appearance in the AST. \note This detector can only visit blocks and will not visit
 * child blocks recursively
 */
class BlockReadWriteDetector : public StmtExprVisitor {
 public:
  explicit BlockReadWriteDetector(const Map<Var, Buffer>& buffer_var_map)
      : buffer_var_map_(buffer_var_map) {}

  /*! \brief Return read regions of the block, skipping buffers in excluded_buffers if given */
  Array<BufferRegion> CollectReads(
      const std::unordered_set<const BufferNode*>* excluded_buffers = nullptr);
  /*! \brief Return write regions of the block, skipping buffers in excluded_buffers if given */
  Array<BufferRegion> CollectWrites(
      const std::unordered_set<const BufferNode*>* excluded_buffers = nullptr);
  /*!
   * \brief Return opaque buffer regions of the block
   * \note The buffer accessed by load/store or call with buffer.data will
   * be marked as opaque.
   */
  Array<BufferRegion> CollectOpaques();
  /*! \brief overload operator() to make sure it accepts a block node */
  void operator()(const Stmt& stmt);

 private:
  /*! \brief Iteration range for loop_vars */
  std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
  /*! \brief Extra iteration range hint for free vars */
  std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
  /*! \brief The buffers that the current block reads */
  std::vector<Buffer> read_buffers_;
  /*! \brief The buffers that the current block writes */
  std::vector<Buffer> writes_buffers_;
  /*! \brief The opaque buffers, i.e. those accessed through buffer.data */
  std::vector<Buffer> opaque_buffers_;
  /*! \brief The read regions of the current block (parallel to read_buffers_) */
  std::vector<std::vector<tvm::arith::IntSet>> read_regions_;
  /*! \brief The write regions of the current block (parallel to writes_buffers_) */
  std::vector<std::vector<tvm::arith::IntSet>> write_regions_;
  /*! \brief The opaque regions of the current block (parallel to opaque_buffers_) */
  std::vector<std::vector<tvm::arith::IntSet>> opaque_regions_;
  /*! \brief The outside buffer data mapping to its buffer */
  Map<Var, Buffer> buffer_var_map_;
  /*! \brief Maps a match_buffer target's data var to its MatchBufferRegion */
  std::unordered_map<const VarNode*, MatchBufferRegion> match_buffers_;
  /*! \brief The analyzer for simplifying */
  arith::Analyzer analyzer_;

  /*!
   * \brief Update read/write buffers and regions with provided buffer and region
   * \param buffers The buffers should be updated
   * \param regions The access regions should be updated
   * \param buffer The provided buffer
   * \param region The provided region
   */
  void Update(std::vector<Buffer>* buffers, std::vector<std::vector<arith::IntSet>>* regions,
              Buffer buffer, std::vector<arith::IntSet> region);

  /*! \brief Helper function to collect access regions. */
  Array<BufferRegion> CollectRegions(
      const std::vector<Buffer>& buffers,
      const std::vector<std::vector<tvm::arith::IntSet>>& regions,
      const std::unordered_set<const BufferNode*>* excluded_buffers = nullptr);

  /*! \brief Helper function to convert matched access region to source region. */
  std::vector<arith::IntSet> ConvertMatchedRegion(const MatchBufferRegion& match_buffer,
                                                  const std::vector<arith::IntSet>& int_sets) const;

  /*! \brief Helper function to update an opaque access. */
  void UpdateOpaque(const Var& buffer_var);

  /*! \brief Helper function to relax the buffer indices */
  arith::IntSet RelaxAccessIndex(const PrimExpr& index);

  void VisitStmt_(const ForNode* op) override;
  void VisitStmt_(const IfThenElseNode* op) override;
  void VisitStmt_(const BlockRealizeNode* op) override;
  void VisitStmt_(const BufferStoreNode* op) override;
  void VisitStmt_(const StoreNode* op) override;
  void VisitExpr_(const BufferLoadNode* op) override;
  void VisitExpr_(const LoadNode* op) override;
  void VisitExpr_(const VarNode* op) override;
  void VisitExpr_(const CallNode* op) override;
};
118 | |
119 | void BlockReadWriteDetector::operator()(const Stmt& stmt) { |
120 | const auto* block = stmt.as<BlockNode>(); |
121 | ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); |
122 | for (const MatchBufferRegion& match_buffer : block->match_buffers) { |
123 | const Var& target_var = match_buffer->buffer->data; |
124 | const Var& source_var = match_buffer->source->buffer->data; |
125 | if (buffer_var_map_.find(source_var) != buffer_var_map_.end()) { |
126 | match_buffers_[target_var.get()] = match_buffer; |
127 | buffer_var_map_.Set(target_var, match_buffer->buffer); |
128 | } |
129 | } |
130 | StmtExprVisitor::operator()(stmt); |
131 | } |
132 | |
133 | Array<BufferRegion> BlockReadWriteDetector::CollectReads( |
134 | const std::unordered_set<const BufferNode*>* excluded_buffers) { |
135 | return CollectRegions(read_buffers_, read_regions_, excluded_buffers); |
136 | } |
137 | |
138 | Array<BufferRegion> BlockReadWriteDetector::CollectWrites( |
139 | const std::unordered_set<const BufferNode*>* excluded_buffers) { |
140 | return CollectRegions(writes_buffers_, write_regions_, excluded_buffers); |
141 | } |
142 | |
143 | Array<BufferRegion> BlockReadWriteDetector::CollectOpaques() { |
144 | return CollectRegions(opaque_buffers_, opaque_regions_); |
145 | } |
146 | |
147 | void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef<Var>(op)); } |
148 | |
149 | void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) { |
150 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
151 | } |
152 | |
153 | void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) { |
154 | std::vector<arith::IntSet> relaxed_region; |
155 | for (const PrimExpr& index : op->indices) { |
156 | relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_)); |
157 | } |
158 | Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region); |
159 | ExprVisitor::VisitExpr_(op); |
160 | } |
161 | |
162 | void BlockReadWriteDetector::VisitStmt_(const ForNode* op) { |
163 | Range range = Range::FromMinExtent(op->min, op->extent); |
164 | dom_map_[op->loop_var.get()] = arith::IntSet::FromRange(range); |
165 | StmtVisitor::VisitStmt_(op); |
166 | dom_map_.erase(op->loop_var.get()); |
167 | } |
168 | |
169 | void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) { |
170 | VisitExpr(op->condition); |
171 | { |
172 | // Visit then branch |
173 | With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true); |
174 | StmtExprVisitor::VisitStmt(op->then_case); |
175 | } |
176 | if (op->else_case) { |
177 | // Visit else branch |
178 | With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false); |
179 | StmtExprVisitor::VisitStmt(op->else_case.value()); |
180 | } |
181 | } |
182 | |
183 | void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { |
184 | if (op->op.same_as(builtin::tvm_access_ptr())) { |
185 | const VarNode* buffer_var = op->args[1].as<VarNode>(); |
186 | const IntImmNode* access_mask = op->args[4].as<IntImmNode>(); |
187 | if (buffer_var && access_mask) { |
188 | auto it = buffer_var_map_.find(GetRef<Var>(buffer_var)); |
189 | if (it != buffer_var_map_.end()) { |
190 | const Buffer& buffer = (*it).second; |
191 | const BufferRegion buffer_region = BufferRegion::FullRegion(buffer); |
192 | const Region& region = buffer_region->region; |
193 | std::vector<arith::IntSet> int_set; |
194 | int_set.reserve(region.size()); |
195 | for (const Range& range : region) { |
196 | int_set.push_back(arith::EvalSet(range, dom_map_)); |
197 | } |
198 | // read access, write access or opaque access |
199 | if ((access_mask->value & 1) && (access_mask->value & 2)) { |
200 | Update(&opaque_buffers_, &opaque_regions_, buffer, int_set); |
201 | } else if (access_mask->value & 1) { |
202 | Update(&read_buffers_, &read_regions_, buffer, int_set); |
203 | } else if (access_mask->value & 2) { |
204 | Update(&writes_buffers_, &write_regions_, buffer, int_set); |
205 | } |
206 | } |
207 | } else { |
208 | StmtExprVisitor::VisitExpr_(op); |
209 | } |
210 | return; |
211 | } |
212 | if (op->op.same_as(builtin::if_then_else())) { |
213 | VisitExpr(op->args[0]); |
214 | { |
215 | // Visit then branch |
216 | With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, true); |
217 | StmtExprVisitor::VisitExpr(op->args[1]); |
218 | } |
219 | { |
220 | // Visit else branch |
221 | With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, false); |
222 | StmtExprVisitor::VisitExpr(op->args[2]); |
223 | } |
224 | return; |
225 | } |
226 | StmtExprVisitor::VisitExpr_(op); |
227 | } |
228 | |
229 | void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) { |
230 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
231 | } |
232 | |
233 | void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) { |
234 | std::vector<arith::IntSet> relaxed_region; |
235 | for (const PrimExpr& index : op->indices) { |
236 | relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_)); |
237 | } |
238 | Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region); |
239 | StmtVisitor::VisitStmt_(op); |
240 | } |
241 | |
242 | void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) { |
243 | /*! \note detector will not visit child block recursively, so it will stop here */ |
244 | std::unordered_map<const VarNode*, PrimExpr> vmap; |
245 | for (size_t i = 0; i < op->block->iter_vars.size(); ++i) { |
246 | vmap[op->block->iter_vars[i]->var.get()] = op->iter_values[i]; |
247 | } |
248 | for (const auto& read : op->block->reads) { |
249 | std::vector<arith::IntSet> relaxed_region; |
250 | for (const auto& range : read->region) { |
251 | relaxed_region.push_back( |
252 | arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent( |
253 | Substitute(range->min, vmap), Substitute(range->extent, vmap))), |
254 | dom_map_)); |
255 | } |
256 | Update(&read_buffers_, &read_regions_, read->buffer, relaxed_region); |
257 | } |
258 | for (const auto& write : op->block->writes) { |
259 | std::vector<arith::IntSet> relaxed_region; |
260 | for (const auto& range : write->region) { |
261 | relaxed_region.push_back( |
262 | arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent( |
263 | Substitute(range->min, vmap), Substitute(range->extent, vmap))), |
264 | dom_map_)); |
265 | } |
266 | Update(&writes_buffers_, &write_regions_, write->buffer, relaxed_region); |
267 | } |
268 | } |
269 | |
270 | std::vector<arith::IntSet> BlockReadWriteDetector::ConvertMatchedRegion( |
271 | const MatchBufferRegion& match_buffer, const std::vector<arith::IntSet>& int_sets) const { |
272 | const Buffer& buffer = match_buffer->buffer; |
273 | |
274 | Region region; |
275 | region.reserve(int_sets.size()); |
276 | ICHECK_EQ(buffer->shape.size(), int_sets.size()); |
277 | for (size_t i = 0; i < int_sets.size(); ++i) { |
278 | const tvm::arith::IntSet& int_set = int_sets[i]; |
279 | region.push_back(int_set.CoverRange(Range::FromMinExtent(0, buffer->shape[i]))); |
280 | } |
281 | |
282 | region = ConvertRegion(match_buffer, region); |
283 | |
284 | std::vector<arith::IntSet> result; |
285 | result.reserve(region.size()); |
286 | for (const Range& range : region) { |
287 | result.push_back(arith::EvalSet(range, dom_map_)); |
288 | } |
289 | return result; |
290 | } |
291 | |
292 | void BlockReadWriteDetector::Update(std::vector<Buffer>* buffers, |
293 | std::vector<std::vector<arith::IntSet>>* regions, Buffer buffer, |
294 | std::vector<arith::IntSet> region) { |
295 | if (buffer_var_map_.find(buffer->data) == buffer_var_map_.end()) return; |
296 | // Handle match_buffer remap |
297 | auto it = match_buffers_.find(buffer->data.get()); |
298 | if (it != match_buffers_.end()) { |
299 | const MatchBufferRegion& match_buffer = it->second; |
300 | buffer = match_buffer->source->buffer; |
301 | region = ConvertMatchedRegion(match_buffer, std::move(region)); |
302 | } |
303 | ICHECK_EQ(buffers->size(), regions->size()) |
304 | << " Expected the buffer and regions to have the same size " ; |
305 | for (size_t i = 0; i < regions->size(); ++i) { |
306 | if ((*buffers)[i].same_as(buffer)) { |
307 | ICHECK_EQ((*regions)[i].size(), region.size()) << "Inconsistent buffer dimension" ; |
308 | for (size_t j = 0; j < region.size(); ++j) { |
309 | (*regions)[i][j] = arith::Union({(*regions)[i][j], region[j]}); |
310 | } |
311 | return; |
312 | } |
313 | } |
314 | buffers->push_back(std::move(buffer)); |
315 | regions->push_back(std::move(region)); |
316 | } |
317 | |
318 | Array<BufferRegion> BlockReadWriteDetector::CollectRegions( |
319 | const std::vector<Buffer>& buffers, const std::vector<std::vector<tvm::arith::IntSet>>& regions, |
320 | const std::unordered_set<const BufferNode*>* excluded_buffers) { |
321 | ICHECK_EQ(buffers.size(), regions.size()); |
322 | Array<BufferRegion> res; |
323 | res.reserve(buffers.size()); |
324 | for (size_t i = 0; i < regions.size(); ++i) { |
325 | if (excluded_buffers != nullptr && excluded_buffers->count(buffers[i].get())) { |
326 | continue; |
327 | } |
328 | Array<Range> region; |
329 | region.reserve(regions[i].size()); |
330 | ICHECK_EQ(buffers[i]->shape.size(), regions[i].size()); |
331 | for (size_t j = 0; j < regions[i].size(); j++) { |
332 | const tvm::arith::IntSet& range = regions[i][j]; |
333 | region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j]))); |
334 | } |
335 | res.push_back(BufferRegion(buffers[i], region)); |
336 | } |
337 | return res; |
338 | } |
339 | |
340 | void BlockReadWriteDetector::UpdateOpaque(const Var& buffer_var) { |
341 | auto it = buffer_var_map_.find(buffer_var); |
342 | if (it != buffer_var_map_.end()) { |
343 | const Buffer& buffer = (*it).second; |
344 | const BufferRegion buffer_region = BufferRegion::FullRegion(buffer); |
345 | const Region& region = buffer_region->region; |
346 | std::vector<arith::IntSet> int_set; |
347 | int_set.reserve(region.size()); |
348 | for (const Range& range : region) { |
349 | int_set.push_back(arith::EvalSet(range, dom_map_)); |
350 | } |
351 | Update(&opaque_buffers_, &opaque_regions_, buffer, int_set); |
352 | } |
353 | } |
354 | |
355 | Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block, |
356 | const Map<Var, Buffer>& buffer_var_map) { |
357 | BlockReadWriteDetector detector(buffer_var_map); |
358 | detector(block); |
359 | Array<BufferRegion> writes = detector.CollectWrites(); |
360 | std::unordered_set<const BufferNode*> excluded_buffers; |
361 | // exclude write buffers from read regions for reductions if init block is defined. |
362 | if (block->init.defined()) { |
363 | for (const BufferRegion& write_access : writes) { |
364 | excluded_buffers.insert(write_access->buffer.get()); |
365 | } |
366 | } |
367 | Array<BufferRegion> reads = detector.CollectReads(&excluded_buffers); |
368 | Array<BufferRegion> opaques = detector.CollectOpaques(); |
369 | return {reads, writes, opaques}; |
370 | } |
371 | |
372 | Array<Array<BufferRegion>> GetBlockReadWriteRegion(const Block& block, |
373 | const Map<Var, Buffer>& buffer_var_map) { |
374 | BlockReadWriteDetector detector(buffer_var_map); |
375 | detector(block); |
376 | Array<BufferRegion> opaques = detector.CollectOpaques(); |
377 | std::unordered_set<const BufferNode*> excluded_buffers; |
378 | for (const BufferRegion& opaque_access : opaques) { |
379 | excluded_buffers.insert(opaque_access->buffer.get()); |
380 | } |
381 | Array<BufferRegion> writes = detector.CollectWrites(&excluded_buffers); |
382 | if (block->init.defined()) { |
383 | for (const BufferRegion& write_access : writes) { |
384 | excluded_buffers.insert(write_access->buffer.get()); |
385 | } |
386 | } |
387 | Array<BufferRegion> reads = detector.CollectReads(&excluded_buffers); |
388 | for (const BufferRegion& opaque_access : opaques) { |
389 | reads.push_back(opaque_access); |
390 | writes.push_back(opaque_access); |
391 | } |
392 | return {reads, writes}; |
393 | } |
394 | |
// FFI registration: expose both analysis entry points to the Python API.
TVM_REGISTER_GLOBAL("tir.analysis.GetBlockAccessRegion" ).set_body_typed(GetBlockAccessRegion);
TVM_REGISTER_GLOBAL("tir.analysis.GetBlockReadWriteRegion" ).set_body_typed(GetBlockReadWriteRegion);
397 | |
398 | } // namespace tir |
399 | } // namespace tvm |
400 | |