/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tir/analysis/block_region_detector.cc
 * \brief Detect block read/write regions by visiting its body
 */

#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include "../transforms/ir_utils.h"

namespace tvm {
namespace tir {

/*!
 * \brief Detect which regions of tensors in this block are read or written to.
 *  Regions are sorted by order of appearance in the AST.
 * \note This detector can only visit blocks and will not visit child blocks recursively.
 */
class BlockReadWriteDetector : public StmtExprVisitor {
 public:
  explicit BlockReadWriteDetector(const Map<Var, Buffer>& buffer_var_map)
      : buffer_var_map_(buffer_var_map) {}

  /*! \brief Return the read regions of the block */
  Array<BufferRegion> CollectReads(
      const std::unordered_set<const BufferNode*>* excluded_buffers = nullptr);
  /*! \brief Return the write regions of the block */
  Array<BufferRegion> CollectWrites(
      const std::unordered_set<const BufferNode*>* excluded_buffers = nullptr);
  /*!
   * \brief Return the opaque buffer regions of the block
   * \note A buffer accessed by Load/Store, or by a call taking buffer->data,
   *  will be marked as opaque.
   */
  Array<BufferRegion> CollectOpaques();
  /*! \brief Overload operator() to make sure it only accepts a Block node */
  void operator()(const Stmt& stmt);

 private:
  /*! \brief Iteration ranges of the loop vars */
  std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
  /*! \brief Extra iteration range hints for free vars */
  std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
  /*! \brief The buffers that the current block reads */
  std::vector<Buffer> read_buffers_;
  /*! \brief The buffers that the current block writes */
  std::vector<Buffer> writes_buffers_;
  /*! \brief The opaque buffers, which are accessed via buffer->data */
  std::vector<Buffer> opaque_buffers_;
  /*! \brief The read regions of the current block */
  std::vector<std::vector<tvm::arith::IntSet>> read_regions_;
  /*! \brief The write regions of the current block */
  std::vector<std::vector<tvm::arith::IntSet>> write_regions_;
  /*! \brief The opaque regions of the current block */
  std::vector<std::vector<tvm::arith::IntSet>> opaque_regions_;
  /*! \brief The mapping from outside buffer data vars to their buffers */
  Map<Var, Buffer> buffer_var_map_;
  /*! \brief The mapping from target buffer vars to their MatchBufferRegion */
  std::unordered_map<const VarNode*, MatchBufferRegion> match_buffers_;
  /*! \brief The analyzer for simplification */
  arith::Analyzer analyzer_;

  /*!
   * \brief Update the read/write buffers and regions with the provided buffer and region
   * \param buffers The buffers to be updated
   * \param regions The access regions to be updated
   * \param buffer The provided buffer
   * \param region The provided region
   */
  void Update(std::vector<Buffer>* buffers, std::vector<std::vector<arith::IntSet>>* regions,
              Buffer buffer, std::vector<arith::IntSet> region);

  /*! \brief Helper function to collect access regions. */
  Array<BufferRegion> CollectRegions(
      const std::vector<Buffer>& buffers,
      const std::vector<std::vector<tvm::arith::IntSet>>& regions,
      const std::unordered_set<const BufferNode*>* excluded_buffers = nullptr);

  /*! \brief Helper function to convert a matched access region to a region of the source buffer. */
  std::vector<arith::IntSet> ConvertMatchedRegion(const MatchBufferRegion& match_buffer,
                                                  const std::vector<arith::IntSet>& int_sets) const;

  /*! \brief Helper function to update an opaque access. */
  void UpdateOpaque(const Var& buffer_var);

  /*! \brief Helper function to relax the buffer indices */
  arith::IntSet RelaxAccessIndex(const PrimExpr& index);

  void VisitStmt_(const ForNode* op) override;
  void VisitStmt_(const IfThenElseNode* op) override;
  void VisitStmt_(const BlockRealizeNode* op) override;
  void VisitStmt_(const BufferStoreNode* op) override;
  void VisitStmt_(const StoreNode* op) override;
  void VisitExpr_(const BufferLoadNode* op) override;
  void VisitExpr_(const LoadNode* op) override;
  void VisitExpr_(const VarNode* op) override;
  void VisitExpr_(const CallNode* op) override;
};
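
// A minimal usage sketch (assuming `block` is a tir::Block and `buffer_var_map`
// maps the data vars of every buffer visible outside the block):
//
//   BlockReadWriteDetector detector(buffer_var_map);
//   detector(block);
//   Array<BufferRegion> reads = detector.CollectReads();
//   Array<BufferRegion> writes = detector.CollectWrites();
//   Array<BufferRegion> opaques = detector.CollectOpaques();
//
// GetBlockAccessRegion and GetBlockReadWriteRegion below follow this pattern.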

void BlockReadWriteDetector::operator()(const Stmt& stmt) {
  const auto* block = stmt.as<BlockNode>();
  ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey();
  // Record the match_buffer remappings so that accesses to a matched target
  // buffer can later be converted back to regions of its source buffer.
  for (const MatchBufferRegion& match_buffer : block->match_buffers) {
    const Var& target_var = match_buffer->buffer->data;
    const Var& source_var = match_buffer->source->buffer->data;
    if (buffer_var_map_.find(source_var) != buffer_var_map_.end()) {
      match_buffers_[target_var.get()] = match_buffer;
      buffer_var_map_.Set(target_var, match_buffer->buffer);
    }
  }
  StmtExprVisitor::operator()(stmt);
}

Array<BufferRegion> BlockReadWriteDetector::CollectReads(
    const std::unordered_set<const BufferNode*>* excluded_buffers) {
  return CollectRegions(read_buffers_, read_regions_, excluded_buffers);
}

Array<BufferRegion> BlockReadWriteDetector::CollectWrites(
    const std::unordered_set<const BufferNode*>* excluded_buffers) {
  return CollectRegions(writes_buffers_, write_regions_, excluded_buffers);
}

Array<BufferRegion> BlockReadWriteDetector::CollectOpaques() {
  return CollectRegions(opaque_buffers_, opaque_regions_);
}

void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef<Var>(op)); }

void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) {
  LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
}

void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
  std::vector<arith::IntSet> relaxed_region;
  for (const PrimExpr& index : op->indices) {
    relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_));
  }
  Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
  ExprVisitor::VisitExpr_(op);
}

void BlockReadWriteDetector::VisitStmt_(const ForNode* op) {
  Range range = Range::FromMinExtent(op->min, op->extent);
  dom_map_[op->loop_var.get()] = arith::IntSet::FromRange(range);
  StmtVisitor::VisitStmt_(op);
  dom_map_.erase(op->loop_var.get());
}

void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) {
  VisitExpr(op->condition);
  {
    // Visit then branch
    With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
    StmtExprVisitor::VisitStmt(op->then_case);
  }
  if (op->else_case) {
    // Visit else branch
    With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
    StmtExprVisitor::VisitStmt(op->else_case.value());
  }
}

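// builtin::tvm_access_ptr(ptr_type, data, offset, extent, rw_mask) makes an
// opaque pointer-based access; args[4] (rw_mask) encodes the access kind:
// bit 0 (value 1) marks a read and bit 1 (value 2) marks a write, so a mask
// with both bits set is recorded as an opaque read-write access over the
// full region of the buffer.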
void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
  if (op->op.same_as(builtin::tvm_access_ptr())) {
    const VarNode* buffer_var = op->args[1].as<VarNode>();
    const IntImmNode* access_mask = op->args[4].as<IntImmNode>();
    if (buffer_var && access_mask) {
      auto it = buffer_var_map_.find(GetRef<Var>(buffer_var));
      if (it != buffer_var_map_.end()) {
        const Buffer& buffer = (*it).second;
        const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
        const Region& region = buffer_region->region;
        std::vector<arith::IntSet> int_set;
        int_set.reserve(region.size());
        for (const Range& range : region) {
          int_set.push_back(arith::EvalSet(range, dom_map_));
        }
        // read access, write access or opaque access
        if ((access_mask->value & 1) && (access_mask->value & 2)) {
          Update(&opaque_buffers_, &opaque_regions_, buffer, int_set);
        } else if (access_mask->value & 1) {
          Update(&read_buffers_, &read_regions_, buffer, int_set);
        } else if (access_mask->value & 2) {
          Update(&writes_buffers_, &write_regions_, buffer, int_set);
        }
      }
    } else {
      StmtExprVisitor::VisitExpr_(op);
    }
    return;
  }
  if (op->op.same_as(builtin::if_then_else())) {
    VisitExpr(op->args[0]);
    {
      // Visit then branch
      With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, true);
      StmtExprVisitor::VisitExpr(op->args[1]);
    }
    {
      // Visit else branch
      With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, false);
      StmtExprVisitor::VisitExpr(op->args[2]);
    }
    return;
  }
  StmtExprVisitor::VisitExpr_(op);
}

void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) {
  LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
}

void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
  std::vector<arith::IntSet> relaxed_region;
  for (const PrimExpr& index : op->indices) {
    relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_));
  }
  Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
  StmtVisitor::VisitStmt_(op);
}

void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) {
  /*! \note The detector does not visit child blocks recursively, so it stops here */
  std::unordered_map<const VarNode*, PrimExpr> vmap;
  for (size_t i = 0; i < op->block->iter_vars.size(); ++i) {
    vmap[op->block->iter_vars[i]->var.get()] = op->iter_values[i];
  }
  for (const auto& read : op->block->reads) {
    std::vector<arith::IntSet> relaxed_region;
    for (const auto& range : read->region) {
      relaxed_region.push_back(
          arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent(
                             Substitute(range->min, vmap), Substitute(range->extent, vmap))),
                         dom_map_));
    }
    Update(&read_buffers_, &read_regions_, read->buffer, relaxed_region);
  }
  for (const auto& write : op->block->writes) {
    std::vector<arith::IntSet> relaxed_region;
    for (const auto& range : write->region) {
      relaxed_region.push_back(
          arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent(
                             Substitute(range->min, vmap), Substitute(range->extent, vmap))),
                         dom_map_));
    }
    Update(&writes_buffers_, &write_regions_, write->buffer, relaxed_region);
  }
}

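// Convert an access region recorded against a match_buffer target back to a
// region of its source buffer: cover each IntSet with a Range within the
// target's shape, remap the ranges through ConvertRegion, then re-evaluate
// the remapped ranges under the current loop domains.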
std::vector<arith::IntSet> BlockReadWriteDetector::ConvertMatchedRegion(
    const MatchBufferRegion& match_buffer, const std::vector<arith::IntSet>& int_sets) const {
  const Buffer& buffer = match_buffer->buffer;

  Region region;
  region.reserve(int_sets.size());
  ICHECK_EQ(buffer->shape.size(), int_sets.size());
  for (size_t i = 0; i < int_sets.size(); ++i) {
    const tvm::arith::IntSet& int_set = int_sets[i];
    region.push_back(int_set.CoverRange(Range::FromMinExtent(0, buffer->shape[i])));
  }

  region = ConvertRegion(match_buffer, region);

  std::vector<arith::IntSet> result;
  result.reserve(region.size());
  for (const Range& range : region) {
    result.push_back(arith::EvalSet(range, dom_map_));
  }
  return result;
}

void BlockReadWriteDetector::Update(std::vector<Buffer>* buffers,
                                    std::vector<std::vector<arith::IntSet>>* regions, Buffer buffer,
                                    std::vector<arith::IntSet> region) {
  if (buffer_var_map_.find(buffer->data) == buffer_var_map_.end()) return;
  // Handle match_buffer remap
  auto it = match_buffers_.find(buffer->data.get());
  if (it != match_buffers_.end()) {
    const MatchBufferRegion& match_buffer = it->second;
    buffer = match_buffer->source->buffer;
    region = ConvertMatchedRegion(match_buffer, std::move(region));
  }
  ICHECK_EQ(buffers->size(), regions->size())
      << "Expected the buffers and regions to have the same size";
  // If the buffer has been accessed before, union the new region with the
  // recorded one dimension by dimension; otherwise record a new entry.
  for (size_t i = 0; i < regions->size(); ++i) {
    if ((*buffers)[i].same_as(buffer)) {
      ICHECK_EQ((*regions)[i].size(), region.size()) << "Inconsistent buffer dimension";
      for (size_t j = 0; j < region.size(); ++j) {
        (*regions)[i][j] = arith::Union({(*regions)[i][j], region[j]});
      }
      return;
    }
  }
  buffers->push_back(std::move(buffer));
  regions->push_back(std::move(region));
}

Array<BufferRegion> BlockReadWriteDetector::CollectRegions(
    const std::vector<Buffer>& buffers, const std::vector<std::vector<tvm::arith::IntSet>>& regions,
    const std::unordered_set<const BufferNode*>* excluded_buffers) {
  ICHECK_EQ(buffers.size(), regions.size());
  Array<BufferRegion> res;
  res.reserve(buffers.size());
  for (size_t i = 0; i < regions.size(); ++i) {
    if (excluded_buffers != nullptr && excluded_buffers->count(buffers[i].get())) {
      continue;
    }
    Array<Range> region;
    region.reserve(regions[i].size());
    ICHECK_EQ(buffers[i]->shape.size(), regions[i].size());
    for (size_t j = 0; j < regions[i].size(); j++) {
      const tvm::arith::IntSet& range = regions[i][j];
      region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j])));
    }
    res.push_back(BufferRegion(buffers[i], region));
  }
  return res;
}

void BlockReadWriteDetector::UpdateOpaque(const Var& buffer_var) {
  auto it = buffer_var_map_.find(buffer_var);
  if (it != buffer_var_map_.end()) {
    const Buffer& buffer = (*it).second;
    const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
    const Region& region = buffer_region->region;
    std::vector<arith::IntSet> int_set;
    int_set.reserve(region.size());
    for (const Range& range : region) {
      int_set.push_back(arith::EvalSet(range, dom_map_));
    }
    Update(&opaque_buffers_, &opaque_regions_, buffer, int_set);
  }
}

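/*!
 * \brief Auto-detect the access regions of a block by visiting its body.
 * \param block The block to be analyzed.
 * \param buffer_var_map The mapping from buffer data vars to the buffers
 *        visible outside the block.
 * \return Three arrays of BufferRegion: the read regions, the write regions
 *         and the opaque regions, each sorted by order of appearance in the
 *         AST. If the block has an init statement (a reduction block), the
 *         written buffers are excluded from the read regions.
 */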
Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
                                                const Map<Var, Buffer>& buffer_var_map) {
  BlockReadWriteDetector detector(buffer_var_map);
  detector(block);
  Array<BufferRegion> writes = detector.CollectWrites();
  std::unordered_set<const BufferNode*> excluded_buffers;
  // Exclude the write buffers from the read regions for reduction blocks,
  // i.e. when an init block is defined.
  if (block->init.defined()) {
    for (const BufferRegion& write_access : writes) {
      excluded_buffers.insert(write_access->buffer.get());
    }
  }
  Array<BufferRegion> reads = detector.CollectReads(&excluded_buffers);
  Array<BufferRegion> opaques = detector.CollectOpaques();
  return {reads, writes, opaques};
}

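/*!
 * \brief Auto-detect the read/write regions of a block, folding opaque
 *        accesses into both sides.
 * \param block The block to be analyzed.
 * \param buffer_var_map The mapping from buffer data vars to the buffers
 *        visible outside the block.
 * \return Two arrays of BufferRegion: the read regions and the write regions.
 *         Each opaquely accessed buffer is reported once in both arrays
 *         rather than in the plain read/write regions.
 */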
Array<Array<BufferRegion>> GetBlockReadWriteRegion(const Block& block,
                                                   const Map<Var, Buffer>& buffer_var_map) {
  BlockReadWriteDetector detector(buffer_var_map);
  detector(block);
  Array<BufferRegion> opaques = detector.CollectOpaques();
  std::unordered_set<const BufferNode*> excluded_buffers;
  for (const BufferRegion& opaque_access : opaques) {
    excluded_buffers.insert(opaque_access->buffer.get());
  }
  Array<BufferRegion> writes = detector.CollectWrites(&excluded_buffers);
  if (block->init.defined()) {
    for (const BufferRegion& write_access : writes) {
      excluded_buffers.insert(write_access->buffer.get());
    }
  }
  Array<BufferRegion> reads = detector.CollectReads(&excluded_buffers);
  // An opaque access is conservatively treated as both a read and a write.
  for (const BufferRegion& opaque_access : opaques) {
    reads.push_back(opaque_access);
    writes.push_back(opaque_access);
  }
  return {reads, writes};
}

TVM_REGISTER_GLOBAL("tir.analysis.GetBlockAccessRegion").set_body_typed(GetBlockAccessRegion);
TVM_REGISTER_GLOBAL("tir.analysis.GetBlockReadWriteRegion").set_body_typed(GetBlockReadWriteRegion);
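
// Both functions are also reachable through the global registry. A sketch of
// calling them via the C++ packed-function FFI (the Python bindings use the
// same registered names):
//
//   const runtime::PackedFunc* f =
//       runtime::Registry::Get("tir.analysis.GetBlockAccessRegion");
//   Array<Array<BufferRegion>> regions = (*f)(block, buffer_var_map);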

}  // namespace tir
}  // namespace tvm