1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
#include <algorithm>
#include <cmath>
#include <functional>
#include <sstream>
#include <utility>
#include <vector>
20 | |
21 | #include "../ir_comparator.h" |
22 | #include "../utils.h" |
23 | |
24 | namespace tvm { |
25 | namespace tir { |
26 | |
27 | namespace { |
28 | |
29 | struct RollingBufferInfo { |
  /*! \brief The buffer to be rolled over. */
  Buffer old_buffer;
  /*! \brief The new buffer, shrunk along the rolling axis. */
  Buffer new_buffer;
  /*! \brief The index of the axis to roll over. */
  int rolling_axis;
  /*! \brief The extent of the rolling axis in the new buffer. */
  PrimExpr rolling_extent;
  /*! \brief For each axis, the element overlap between consecutive iterations. */
  std::vector<int> axis_overlaps;
  /*! \brief For each axis, the loop var that drives it, if any. */
  std::vector<Optional<Var>> axis_iter_vars;
36 | /*! \brief The map used for ScheduleStateNode::Replace. */ |
37 | Map<Block, Block> block_reuse; |
38 | }; |
39 | |
40 | BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region, |
41 | const Map<Var, arith::IntSet>& dom_map) { |
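  // Substitute the block iter vars with their bindings, evaluate each dimension
  // of the region over the given loop domains, and cover the result with a
  // concrete Range, falling back to the full buffer extent when needed.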
42 | Array<arith::IntSet> relaxed_intsets = |
43 | arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); |
44 | Region relaxed_region; |
45 | relaxed_region.reserve(relaxed_intsets.size()); |
46 | for (size_t i = 0; i < relaxed_intsets.size(); ++i) { |
47 | relaxed_region.push_back( |
48 | relaxed_intsets[i].CoverRange(Range::FromMinExtent(0, buffer_region->buffer->shape[i]))); |
49 | } |
50 | return BufferRegion(buffer_region->buffer, relaxed_region); |
51 | } |
52 | |
53 | class RollingBufferDependencyError : public ScheduleError { |
54 | public: |
55 | explicit RollingBufferDependencyError(IRModule mod, Block block) |
56 | : mod_(mod), block_(std::move(block)) {} |
57 | |
58 | String FastErrorString() const final { |
59 | return "ScheduleError: The target block is required to have only RAW dependencies" ; |
60 | } |
61 | |
62 | String DetailRenderTemplate() const final { |
63 | return "The target block {0} is required to have only RAW dependencies" ; |
64 | } |
65 | |
66 | IRModule mod() const final { return mod_; } |
67 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
68 | |
69 | /*! |
70 | * \brief Check if the block has only RAW dependencies. |
71 | * \param self The schedule state |
72 | * \param block_sref The sref of the block to be checked |
73 | * \param scope_root_sref The sref of the scope root |
74 | * \throw ScheduleError if the block has WAW or WAR dependency. |
75 | */ |
76 | static void Check(const ScheduleState& self, const StmtSRef& block_sref, |
77 | const StmtSRef& scope_root_sref) { |
78 | BlockScope scope = self->GetBlockScope(scope_root_sref); |
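    // Dependencies with this block as destination come from its producers;
    // dependencies with it as source go to its consumers. Both directions
    // must contain only read-after-write (RAW) edges.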
    for (const Dependency& producer : scope->GetDepsByDst(block_sref)) {
      if (producer->kind != DepKind::kRAW) {
        const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
        throw RollingBufferDependencyError(self->mod, GetRef<Block>(block));
      }
    }
    for (const Dependency& consumer : scope->GetDepsBySrc(block_sref)) {
      if (consumer->kind != DepKind::kRAW) {
        const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
        throw RollingBufferDependencyError(self->mod, GetRef<Block>(block));
      }
    }
91 | } |
92 | |
93 | private: |
94 | IRModule mod_; |
95 | Block block_; |
96 | }; |
97 | |
98 | class RollingBufferMatchError : public ScheduleError { |
99 | public: |
100 | RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region) |
101 | : mod_(mod), block_(block), buffer_region_(buffer_region) {} |
102 | String FastErrorString() const final { |
103 | return "ScheduleError: rolling_buffer expect the buffer region to have at least one dimention" |
104 | "matching the rolling pattern such as: hh.outer * stride + hh.inner" ; |
105 | } |
106 | String DetailRenderTemplate() const final { |
107 | std::ostringstream os; |
108 | os << "The target buffer " << buffer_region_->buffer->name << " with region " |
109 | << buffer_region_->region |
110 | << " should have at least one dimension range that matches a rolling pattern " |
111 | "such as hh.outer * stride + hh.inner. " ; |
112 | return os.str(); |
113 | } |
114 | |
115 | IRModule mod() const final { return mod_; } |
116 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
117 | |
118 | private: |
119 | IRModule mod_; |
120 | Block block_; |
121 | BufferRegion buffer_region_; |
122 | }; |
123 | |
124 | class RollingBufferInsertionError : public ScheduleError { |
125 | public: |
126 | RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block) |
127 | : mod_(mod), buffer_(std::move(buffer)), block_(block) {} |
128 | String FastErrorString() const final { |
129 | return "ScheduleError: rolling_buffer injection is invalid, the lca of the access " |
130 | "location of the target buffer is not a for loop. " ; |
131 | } |
132 | |
133 | String DetailRenderTemplate() const final { |
134 | std::ostringstream os; |
135 | os << "rolling_buffer injection is invalid. The block {0} should be tiled so that " |
136 | << "the lca of the access location of the target buffer " << buffer_->name |
137 | << " is a for loop. " ; |
138 | return os.str(); |
139 | } |
140 | IRModule mod() const final { return mod_; } |
141 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
142 | |
143 | private: |
144 | IRModule mod_; |
145 | Buffer buffer_; |
146 | Block block_; |
147 | }; |
148 | |
149 | class RollingBufferInfoCollector { |
150 | public: |
151 | static RollingBufferInfo CheckAndGetRollingBufferInfo(const IRModule& mod, |
152 | const StmtSRef& block_sref, |
153 | const BufferRegion& buffer_region) { |
154 | RollingBufferInfoCollector collector; |
155 | if (!collector.MatchRollingBuffer(block_sref, buffer_region)) { |
156 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
157 | throw RollingBufferMatchError(mod, GetRef<Block>(block), buffer_region); |
158 | } |
159 | return collector.info_; |
160 | } |
161 | |
162 | private: |
163 | bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& buffer_region) { |
164 | const Buffer& buffer = buffer_region->buffer; |
165 | const Region& region = buffer_region->region; |
166 | |
167 | std::vector<Optional<Var>> bound_iter_vars; |
168 | std::vector<int> bound_overlaps; |
169 | |
170 | arith::PVar<Var> p_var; |
171 | arith::PVar<IntImm> p_stride, p_divisor; |
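    // For each axis of the accessed region, try to match the min bound against a
    // rolling pattern, recording the loop var that drives the axis (if any) and
    // the element overlap between consecutive iterations (extent - stride).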
172 | for (auto bound : region) { |
173 | auto stride = 0; |
174 | auto divisor = 1; |
175 | |
176 | Optional<Var> iter_var; |
177 | if (floordiv((p_var * p_stride), p_divisor).Match(bound->min)) { |
178 | // Handle the case of fractional strides |
179 | // They take this form: floordiv(hh.outer, 2) |
180 | // Strip the floordiv and keep track of the divisor |
181 | iter_var = p_var.Eval(); |
182 | divisor = p_divisor.Eval()->value; |
183 | stride = std::ceil(static_cast<float>(p_stride.Eval()->value) / divisor); |
184 | } else if ((p_var * p_stride).Match(bound->min)) { |
185 | // The bound is the iter var multiplied by the stride |
186 | iter_var = p_var.Eval(); |
187 | stride = p_stride.Eval()->value; |
188 | } else if (p_var.Match(bound->min)) { |
189 | // If the bound is just a Var, that implies the stride is 1 |
190 | iter_var = p_var.Eval(); |
191 | stride = 1; |
192 | } else if (is_const_int(bound->min)) { |
193 | // If the bound is an int, we can't roll over it |
194 | iter_var = NullOpt; |
195 | } else { |
        // None of the patterns above matched, so the bound is unsupported
197 | return false; |
198 | } |
199 | auto bound_overlap = 0; |
200 | if (iter_var.defined()) { |
201 | auto extent = Downcast<IntImm>(bound->extent)->value; |
202 | bound_overlap = extent - stride; |
        // Since the CompactBufferAllocation pass is responsible for compacting the
        // buffer allocation region, there is no need to roll over an axis whose
        // overlap is not positive, so reset iter_var to NullOpt.
206 | if (bound_overlap <= 0) { |
207 | iter_var = NullOpt; |
208 | } |
209 | } |
210 | bound_iter_vars.push_back(iter_var); |
211 | bound_overlaps.push_back(bound_overlap); |
212 | } |
213 | |
214 | Array<StmtSRef> loop_srefs = GetLoops(block_sref); |
215 | // Pick the outermost iter_var that's mentioned in the bounds |
216 | // to be the rolling axis |
217 | Optional<Var> roll_iter_var; |
    int roll_axis = -1;
219 | for (const tir::StmtSRef& loop_sref : loop_srefs) { |
220 | auto loop_var = loop_sref->StmtAs<ForNode>()->loop_var; |
221 | |
222 | auto it{std::find_if(bound_iter_vars.begin(), bound_iter_vars.end(), [&](Optional<Var> var) { |
223 | return var && (var.get() == loop_var.get()); |
224 | })}; |
225 | if (it != bound_iter_vars.end()) { |
226 | auto i = std::distance(bound_iter_vars.begin(), it); |
227 | roll_iter_var = loop_var; |
228 | roll_axis = i; |
229 | break; |
230 | } |
231 | } |
232 | |
233 | if (!roll_iter_var.defined()) { |
234 | return false; |
235 | } |
236 | Array<PrimExpr> new_shape = buffer->shape; |
237 | new_shape.Set(roll_axis, region[roll_axis]->extent); |
238 | Buffer new_buffer = buffer; |
239 | new_buffer.CopyOnWrite()->shape = new_shape; |
240 | |
241 | info_.old_buffer = buffer; |
242 | info_.new_buffer = new_buffer; |
243 | info_.rolling_axis = roll_axis; |
244 | info_.rolling_extent = region[roll_axis]->extent; |
245 | info_.axis_overlaps = bound_overlaps; |
246 | info_.axis_iter_vars = bound_iter_vars; |
247 | |
248 | return true; |
249 | } |
250 | |
251 | RollingBufferInfo info_; |
252 | }; |
253 | |
254 | class RollingBufferRewriter : public StmtExprMutator { |
255 | public: |
256 | static Stmt Rewrite(const StmtSRef& scope_sref, RollingBufferInfo* info) { |
257 | RollingBufferRewriter rewriter(scope_sref, info); |
258 | return rewriter(GetRef<Stmt>(scope_sref->stmt)); |
259 | } |
260 | |
261 | private: |
262 | explicit RollingBufferRewriter(const StmtSRef& scope_sref, RollingBufferInfo* info) |
263 | : scope_sref_(scope_sref), info_(info) {} |
264 | |
  void RewriteAccessRegion(Array<BufferRegion>* old_access_regions,
                           const Array<BufferRegion>& inferred_access_regions) {
    auto fmutate = [this, &inferred_access_regions](const BufferRegion& buffer_region) {
      if (buffer_region->buffer.same_as(info_->old_buffer)) {
        ICHECK_EQ(inferred_access_regions.size(), 1);
        return inferred_access_regions[0];
      }
      return buffer_region;
    };
    old_access_regions->MutateByApply(fmutate);
  }
276 | |
277 | void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) const { |
278 | Array<PrimExpr> new_indices; |
279 | new_indices.reserve(indices->size()); |
280 | // First modify the access indices to use modulo arithmetic |
281 | // for the rolling axis |
282 | for (size_t i = 0; i < indices->size(); ++i) { |
283 | if (static_cast<int>(i) == info_->rolling_axis) { |
284 | new_indices.push_back(FloorMod((*indices)[i], info_->rolling_extent)); |
285 | } else { |
286 | new_indices.push_back((*indices)[i]); |
287 | } |
288 | } |
289 | // Replace the accessed buffer with the new buffer. |
290 | *buffer = info_->new_buffer; |
291 | *indices = std::move(new_indices); |
292 | } |
293 | |
294 | Stmt VisitStmt_(const BlockNode* block) final { |
295 | Block old_stmt = GetRef<Block>(block); |
296 | Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block)); |
297 | BlockNode* n = stmt.CopyOnWrite(); |
298 | if (block == scope_sref_->stmt) { |
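      // In the scope root block, replace the allocation of the old buffer with
      // the shrunk rolling buffer.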
299 | Array<Buffer> new_alloc_buffers; |
      for (const Buffer& buffer : stmt->alloc_buffers) {
        if (buffer.same_as(info_->old_buffer)) {
          new_alloc_buffers.push_back(info_->new_buffer);
        } else {
          new_alloc_buffers.push_back(buffer);
        }
      }
307 | n->alloc_buffers = std::move(new_alloc_buffers); |
308 | } else { |
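      // In any other block, mark the rolling axis as opaque and re-infer the
      // block's read/write regions against the new buffer.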
309 | Array<IterVar> new_iter_vars; |
310 | for (size_t i = 0; i < stmt->iter_vars.size(); ++i) { |
311 | auto old_iter_var = stmt->iter_vars[i]; |
312 | if (static_cast<int>(i) == info_->rolling_axis) { |
          // All inner loops of the rolling axis have a loop-carried dependency
          // (i.e. each iteration of the rolling axis depends on the results
          // computed in previous iterations by its inner loops), so annotate the
          // iteration type of the rolling axis as 'opaque' to prevent the
          // iteration ranges of its inner loops from being compressed during
          // the lowering phase.
319 | IterVar new_iter_var = |
320 | IterVar(old_iter_var->dom, old_iter_var->var, IterVarType::kOpaque); |
321 | new_iter_vars.push_back(new_iter_var); |
322 | } else { |
323 | new_iter_vars.push_back(old_iter_var); |
324 | } |
325 | } |
326 | Map<Var, Buffer> buffer_data_to_buffer = {{info_->new_buffer->data, info_->new_buffer}}; |
      auto inferred_access_regions = GetBlockReadWriteRegion(stmt, buffer_data_to_buffer);

      n->iter_vars = std::move(new_iter_vars);
      RewriteAccessRegion(&n->reads, inferred_access_regions[0]);
      RewriteAccessRegion(&n->writes, inferred_access_regions[1]);
332 | } |
333 | info_->block_reuse.Set(old_stmt, stmt); |
334 | return std::move(stmt); |
335 | } |
336 | |
337 | Stmt VisitStmt_(const BlockRealizeNode* realize) final { |
338 | BlockRealize stmt = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(realize)); |
339 | // Append block predicate to avoid recomputing elements. |
340 | if (rewrite_block_predicate_) { |
341 | rewrite_block_predicate_ = false; |
342 | PrimExpr condition = stmt->predicate; |
343 | for (size_t i = 0; i < info_->axis_iter_vars.size(); ++i) { |
344 | auto iter_var = info_->axis_iter_vars[i]; |
345 | if (iter_var && info_->axis_overlaps[i] > 0) { |
346 | Var var = iter_var.value(); |
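          // Evaluate the iter value with the rolling loop var fixed at [0, 0];
          // the minimum below isolates the offset contributed by inner loops.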
347 | const Map<Var, arith::IntSet> dmap = {std::make_pair(var, arith::IntSet::Interval(0, 0))}; |
348 | auto iter_value = realize->iter_values[i]; |
349 | arith::Analyzer analyzer; |
350 | auto term_2 = analyzer.int_set(iter_value, dmap).min(); |
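          // Recompute an element only on the first iteration of the rolling axis
          // (var < 1) or when it lies beyond the region that overlaps with the
          // previous iteration (term_2 >= axis_overlap).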
351 | condition = analyzer.Simplify( |
352 | And(condition, Or(LT(var, 1), GE(term_2, info_->axis_overlaps[i])))); |
353 | } |
354 | } |
355 | BlockRealizeNode* n = stmt.CopyOnWrite(); |
356 | n->predicate = condition; |
357 | } |
358 | return std::move(stmt); |
359 | } |
360 | |
361 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
362 | BufferStore stmt = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
363 | if (stmt->buffer.same_as(info_->old_buffer)) { |
364 | BufferStoreNode* n = stmt.CopyOnWrite(); |
365 | RewriteBufferAccess(&n->buffer, &n->indices); |
366 | // Need to add predicate to the current block to avoid recomputing elements. |
367 | rewrite_block_predicate_ = true; |
368 | } |
369 | return std::move(stmt); |
370 | } |
371 | |
372 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
373 | BufferLoad stmt = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
374 | if (stmt->buffer.same_as(info_->old_buffer)) { |
375 | BufferLoadNode* n = stmt.CopyOnWrite(); |
376 | RewriteBufferAccess(&n->buffer, &n->indices); |
377 | } |
378 | return std::move(stmt); |
379 | } |
380 | |
381 | private: |
382 | const StmtSRef& scope_sref_; |
383 | RollingBufferInfo* info_; |
384 | bool rewrite_block_predicate_ = false; |
385 | }; |
386 | |
387 | } // namespace |
388 | |
389 | void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index) { |
390 | /*! |
391 | * Check |
392 | * - The block is not an output block. |
393 | * - The block has only RAW dependencies. |
394 | * - The block is tiled and there is access overlap between adjacent tiles. |
395 | * Mutate |
   * - Select the outermost rollable axis that appears in the block's loop nest
   *   as the 'rolling axis', and shrink the target buffer along that axis.
   * - Use modulo arithmetic on the target buffer's load and store indices
   *   to circularize the buffer along the rolling dimension.
   * - Append a block predicate to avoid recomputing overlapping elements.
401 | */ |
402 | Map<Var, arith::IntSet> dom_map; |
403 | const BlockRealize& realize = GetBlockRealize(self, block_sref); |
404 | const Block& block = realize->block; |
405 | |
406 | // Step 1. Checking index, getting the target buffer region and the parent scope. |
407 | const BufferRegion& buffer_region = |
408 | GetNthAccessBufferRegion(self, block, write_buffer_index, BufferIndexType::kWrite); |
409 | StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); |
410 | // Step 2. Check if the target block is not an output block and has only RAW dependencies. |
411 | CheckNotOutputBlock(self, block_sref, scope_root_sref); |
412 | RollingBufferDependencyError::Check(self, block_sref, scope_root_sref); |
413 | |
  // Step 3. Find the LCA of the access locations of the target buffer and relax the buffer region
415 | Array<StmtSRef> loop_srefs = GetLoops(block_sref); |
416 | Array<StmtSRef> consumers_sref = GetConsumers(self, block_sref); |
417 | consumers_sref.push_back(block_sref); |
418 | StmtSRef lca = GetSRefLowestCommonAncestor(consumers_sref); |
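  // The buffer must persist across iterations of a loop shared by the producer
  // and all consumers; if that LCA is not a loop, rolling is impossible.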
419 | if (!lca->StmtAs<ForNode>()) { |
420 | throw RollingBufferInsertionError(self->mod, buffer_region->buffer, block); |
421 | } |
422 | |
423 | for (auto it = loop_srefs.rbegin(); it != loop_srefs.rend(); ++it) { |
424 | auto stmt = *it; |
    // Stop at the LCA of all the rolling buffer access points.
426 | if (stmt == lca) { |
427 | break; |
428 | } |
429 | For cur_loop = GetRef<For>(stmt->StmtAs<ForNode>()); |
430 | Range range = Range::FromMinExtent(cur_loop->min, cur_loop->extent); |
431 | dom_map.Set(cur_loop->loop_var, arith::IntSet::FromRange(range)); |
432 | } |
433 | BufferRegion relaxed_region = GetRelaxedBufferRegion(realize, buffer_region, dom_map); |
434 | |
435 | // Step 4. Find a valid rolling axis and collect bound overlaps on the target buffer. |
436 | RollingBufferInfo info = RollingBufferInfoCollector::CheckAndGetRollingBufferInfo( |
437 | self->mod, block_sref, relaxed_region); |
438 | // Step 5. Mutate IR to apply rolling access pattern. |
439 | Stmt new_scope_root = RollingBufferRewriter::Rewrite(scope_root_sref, &info); |
440 | |
441 | // Step 6. Update schedule states |
442 | self->Replace(scope_root_sref, new_scope_root, info.block_reuse); |
443 | // Step 7. Regenerate block info from the root block, because `region_cover` for the target block |
444 | // and `stage_pipeline` for the root block are no longer satisfied after rolling buffer injection. |
445 | self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, self->stmt2ref.at(new_scope_root.get()))); |
446 | } |
447 | |
448 | struct RollingBufferTraits : public UnpackedInstTraits<RollingBufferTraits> { |
  static constexpr const char* kName = "RollingBuffer";
450 | static constexpr bool kIsPure = false; |
451 | |
452 | private: |
453 | static constexpr size_t kNumInputs = 1; |
454 | static constexpr size_t kNumAttrs = 1; |
455 | static constexpr size_t kNumDecisions = 0; |
456 | |
457 | static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer write_buffer_index) { |
458 | return sch->RollingBuffer(block, write_buffer_index.IntValue()); |
459 | } |
460 | |
461 | static String UnpackedAsPython(Array<String> outputs, String block, Integer write_buffer_index) { |
    PythonAPICall py("rolling_buffer");
    py.Input("block", block);
    py.Input("write_buffer_index", write_buffer_index);
465 | return py.Str(); |
466 | } |
467 | |
468 | template <typename> |
469 | friend struct ::tvm::tir::UnpackedInstTraits; |
470 | }; |
471 | |
472 | TVM_REGISTER_INST_KIND_TRAITS(RollingBufferTraits); |
473 | } // namespace tir |
474 | } // namespace tvm |
475 | |