/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include <algorithm>
#include <cmath>
#include <functional>
#include <sstream>
#include <utility>
#include <vector>

#include "../ir_comparator.h"
#include "../utils.h"

namespace tvm {
namespace tir {

namespace {

/*! \brief The information needed to perform the rolling buffer rewrite. */
struct RollingBufferInfo {
  /*! \brief The buffer to be rolled. */
  Buffer old_buffer;
  /*! \brief The rolling buffer that replaces the old buffer, trimmed along the rolling axis. */
  Buffer new_buffer;
  /*! \brief The index of the axis along which the buffer is rolled. */
  int rolling_axis;
  /*! \brief The extent of the rolling axis in the new buffer. */
  PrimExpr rolling_extent;
  /*! \brief The per-axis overlap (extent - stride) between adjacent iterations. */
  std::vector<int> axis_overlaps;
  /*! \brief The per-axis loop var driving the axis, or NullOpt if the axis is not rolled. */
  std::vector<Optional<Var>> axis_iter_vars;
  /*! \brief The map used for ScheduleStateNode::Replace. */
  Map<Block, Block> block_reuse;
};

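/*!
 * \brief Relax a buffer region over the domains of the loops between the block and the LCA of
 *        all accesses to the buffer.
 * \param realize The block realize providing the iter var bindings to substitute.
 * \param buffer_region The buffer region accessed by the block.
 * \param dom_map The domains of the loop vars to relax over.
 * \return The relaxed region over the same buffer.
 */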
BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region,
                                    const Map<Var, arith::IntSet>& dom_map) {
  Array<arith::IntSet> relaxed_intsets =
      arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map);
  Region relaxed_region;
  relaxed_region.reserve(relaxed_intsets.size());
  for (size_t i = 0; i < relaxed_intsets.size(); ++i) {
    relaxed_region.push_back(
        relaxed_intsets[i].CoverRange(Range::FromMinExtent(0, buffer_region->buffer->shape[i])));
  }
  return BufferRegion(buffer_region->buffer, relaxed_region);
}

class RollingBufferDependencyError : public ScheduleError {
 public:
  explicit RollingBufferDependencyError(IRModule mod, Block block)
      : mod_(mod), block_(std::move(block)) {}

  String FastErrorString() const final {
    return "ScheduleError: The target block is required to have only RAW dependencies";
  }

  String DetailRenderTemplate() const final {
    return "The target block {0} is required to have only RAW dependencies";
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

  /*!
   * \brief Check if the block has only RAW dependencies.
   * \param self The schedule state
   * \param block_sref The sref of the block to be checked
   * \param scope_root_sref The sref of the scope root
   * \throw ScheduleError if the block has a WAW or WAR dependency.
   */
  static void Check(const ScheduleState& self, const StmtSRef& block_sref,
                    const StmtSRef& scope_root_sref) {
    BlockScope scope = self->GetBlockScope(scope_root_sref);
    for (const Dependency& producer : scope->GetDepsByDst(block_sref)) {
      if (producer->kind != DepKind::kRAW) {
        const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
        throw RollingBufferDependencyError(self->mod, GetRef<Block>(block));
      }
    }
    for (const Dependency& consumer : scope->GetDepsBySrc(block_sref)) {
      if (consumer->kind != DepKind::kRAW) {
        const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
        throw RollingBufferDependencyError(self->mod, GetRef<Block>(block));
      }
    }
  }

 private:
  IRModule mod_;
  Block block_;
};

class RollingBufferMatchError : public ScheduleError {
 public:
  RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region)
      : mod_(mod), block_(block), buffer_region_(buffer_region) {}
  String FastErrorString() const final {
    return "ScheduleError: rolling_buffer expects the buffer region to have at least one "
           "dimension matching a rolling pattern such as: hh.outer * stride + hh.inner";
  }
  String DetailRenderTemplate() const final {
    std::ostringstream os;
    os << "The target buffer " << buffer_region_->buffer->name << " with region "
       << buffer_region_->region
       << " should have at least one dimension range that matches a rolling pattern "
          "such as hh.outer * stride + hh.inner. ";
    return os.str();
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

 private:
  IRModule mod_;
  Block block_;
  BufferRegion buffer_region_;
};

class RollingBufferInsertionError : public ScheduleError {
 public:
  RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block)
      : mod_(mod), buffer_(std::move(buffer)), block_(block) {}
  String FastErrorString() const final {
    return "ScheduleError: rolling_buffer injection is invalid: the LCA of the access "
           "locations of the target buffer is not a for loop. ";
  }

  String DetailRenderTemplate() const final {
    std::ostringstream os;
    os << "rolling_buffer injection is invalid. The block {0} should be tiled so that "
       << "the LCA of the access locations of the target buffer " << buffer_->name
       << " is a for loop. ";
    return os.str();
  }
  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

 private:
  IRModule mod_;
  Buffer buffer_;
  Block block_;
};

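/*!
 * \brief Collect the information needed to roll the target buffer.
 *
 * Matches each dimension of the relaxed buffer region against a rolling access pattern, picks
 * the outermost matching loop var as the rolling axis, and records the per-axis overlaps
 * between adjacent iterations.
 */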
class RollingBufferInfoCollector {
 public:
  static RollingBufferInfo CheckAndGetRollingBufferInfo(const IRModule& mod,
                                                        const StmtSRef& block_sref,
                                                        const BufferRegion& buffer_region) {
    RollingBufferInfoCollector collector;
    if (!collector.MatchRollingBuffer(block_sref, buffer_region)) {
      const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
      throw RollingBufferMatchError(mod, GetRef<Block>(block), buffer_region);
    }
    return collector.info_;
  }

 private:
  bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& buffer_region) {
    const Buffer& buffer = buffer_region->buffer;
    const Region& region = buffer_region->region;

    std::vector<Optional<Var>> bound_iter_vars;
    std::vector<int> bound_overlaps;

    arith::PVar<Var> p_var;
    arith::PVar<IntImm> p_stride, p_divisor;
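    // Match each dimension of the (relaxed) region against a rolling access pattern.
    // For example (illustrative): a region whose min is `hh.outer * 4` with extent 6
    // advances by a stride of 4 per iteration while touching 6 elements, so adjacent
    // iterations overlap by extent - stride = 2 elements along that axis.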
    for (auto bound : region) {
      auto stride = 0;
      auto divisor = 1;

      Optional<Var> iter_var;
      if (floordiv((p_var * p_stride), p_divisor).Match(bound->min)) {
        // Handle the case of fractional strides,
        // which take the form: floordiv(hh.outer, 2).
        // Strip the floordiv and keep track of the divisor.
        iter_var = p_var.Eval();
        divisor = p_divisor.Eval()->value;
        stride = std::ceil(static_cast<float>(p_stride.Eval()->value) / divisor);
      } else if ((p_var * p_stride).Match(bound->min)) {
        // The bound is the iter var multiplied by the stride
        iter_var = p_var.Eval();
        stride = p_stride.Eval()->value;
      } else if (p_var.Match(bound->min)) {
        // If the bound is just a Var, the stride is 1
        iter_var = p_var.Eval();
        stride = 1;
      } else if (is_const_int(bound->min)) {
        // If the bound is a constant, we can't roll over it
        iter_var = NullOpt;
      } else {
        // None of the patterns matched, so the bound is unsupported
        return false;
      }
      auto bound_overlap = 0;
      if (iter_var.defined()) {
        auto extent = Downcast<IntImm>(bound->extent)->value;
        bound_overlap = extent - stride;
        // The CompactBufferAllocation pass is responsible for compacting the buffer
        // allocation region, so there is no need to roll over an axis whose overlap is
        // not positive; reset iter_var to NullOpt in that case.
        if (bound_overlap <= 0) {
          iter_var = NullOpt;
        }
      }
      bound_iter_vars.push_back(iter_var);
      bound_overlaps.push_back(bound_overlap);
    }

    Array<StmtSRef> loop_srefs = GetLoops(block_sref);
    // Pick the outermost iter var mentioned in the bounds
    // to be the rolling axis.
    Optional<Var> roll_iter_var;
    int roll_axis = -1;
    for (const tir::StmtSRef& loop_sref : loop_srefs) {
      auto loop_var = loop_sref->StmtAs<ForNode>()->loop_var;

      auto it{std::find_if(bound_iter_vars.begin(), bound_iter_vars.end(), [&](Optional<Var> var) {
        return var && (var.get() == loop_var.get());
      })};
      if (it != bound_iter_vars.end()) {
        auto i = std::distance(bound_iter_vars.begin(), it);
        roll_iter_var = loop_var;
        roll_axis = static_cast<int>(i);
        break;
      }
    }

    if (!roll_iter_var.defined()) {
      return false;
    }
    Array<PrimExpr> new_shape = buffer->shape;
    new_shape.Set(roll_axis, region[roll_axis]->extent);
    Buffer new_buffer = buffer;
    new_buffer.CopyOnWrite()->shape = new_shape;

    info_.old_buffer = buffer;
    info_.new_buffer = new_buffer;
    info_.rolling_axis = roll_axis;
    info_.rolling_extent = region[roll_axis]->extent;
    info_.axis_overlaps = bound_overlaps;
    info_.axis_iter_vars = bound_iter_vars;

    return true;
  }

  RollingBufferInfo info_;
};

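/*!
 * \brief Mutator that applies the rolling access pattern under the scope root.
 *
 * It swaps the old allocation for the trimmed rolling buffer, rewrites accesses along the
 * rolling axis with floormod, marks the rolling block iter as opaque, and appends a block
 * predicate so that overlapping elements are not recomputed.
 */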
class RollingBufferRewriter : public StmtExprMutator {
 public:
  static Stmt Rewrite(const StmtSRef& scope_sref, RollingBufferInfo* info) {
    RollingBufferRewriter rewriter(scope_sref, info);
    return rewriter(GetRef<Stmt>(scope_sref->stmt));
  }

 private:
  explicit RollingBufferRewriter(const StmtSRef& scope_sref, RollingBufferInfo* info)
      : scope_sref_(scope_sref), info_(info) {}

  void RewriteAccessRegion(Array<BufferRegion>* old_access_regions,
                           const Array<BufferRegion>& inferred_access_regions) {
    auto fmutate = [this, &inferred_access_regions](const BufferRegion& buffer_region) {
      if (buffer_region->buffer.same_as(info_->old_buffer)) {
        ICHECK(inferred_access_regions.size() == 1);
        return inferred_access_regions[0];
      }
      return buffer_region;
    };
    old_access_regions->MutateByApply(fmutate);
  }

  void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) const {
    Array<PrimExpr> new_indices;
    new_indices.reserve(indices->size());
    // First modify the access indices to use modulo arithmetic
    // for the rolling axis.
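    // For instance (illustrative), with rolling_extent == 6 an index `hh` on the
    // rolling axis becomes `floormod(hh, 6)`.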
    for (size_t i = 0; i < indices->size(); ++i) {
      if (static_cast<int>(i) == info_->rolling_axis) {
        new_indices.push_back(FloorMod((*indices)[i], info_->rolling_extent));
      } else {
        new_indices.push_back((*indices)[i]);
      }
    }
    // Replace the accessed buffer with the new buffer.
    *buffer = info_->new_buffer;
    *indices = std::move(new_indices);
  }

  Stmt VisitStmt_(const BlockNode* block) final {
    Block old_stmt = GetRef<Block>(block);
    Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block));
    BlockNode* n = stmt.CopyOnWrite();
    if (block == scope_sref_->stmt) {
      Array<Buffer> new_alloc_buffers;
      for (const Buffer& buffer : stmt->alloc_buffers) {
        if (buffer != info_->old_buffer) {
          new_alloc_buffers.push_back(buffer);
        } else {
          new_alloc_buffers.push_back(info_->new_buffer);
        }
      }
      n->alloc_buffers = std::move(new_alloc_buffers);
    } else {
      Array<IterVar> new_iter_vars;
      for (size_t i = 0; i < stmt->iter_vars.size(); ++i) {
        auto old_iter_var = stmt->iter_vars[i];
        if (static_cast<int>(i) == info_->rolling_axis) {
          // All inner loops of the rolling axis have a loop-carried dependency
          // (i.e. each iteration of the rolling axis depends on the results computed by
          // all previous iterations of its inner loops), so annotate the iteration type
          // of the rolling axis as 'opaque' to prevent the iteration ranges of its inner
          // loops from being compressed during the lowering phase.
          IterVar new_iter_var =
              IterVar(old_iter_var->dom, old_iter_var->var, IterVarType::kOpaque);
          new_iter_vars.push_back(new_iter_var);
        } else {
          new_iter_vars.push_back(old_iter_var);
        }
      }
      Map<Var, Buffer> buffer_data_to_buffer = {{info_->new_buffer->data, info_->new_buffer}};
      auto inferred_access_regions = GetBlockReadWriteRegion(stmt, buffer_data_to_buffer);

      n->iter_vars = std::move(new_iter_vars);
      RewriteAccessRegion(&n->reads, inferred_access_regions[0]);
      RewriteAccessRegion(&n->writes, inferred_access_regions[1]);
    }
    info_->block_reuse.Set(old_stmt, stmt);
    return std::move(stmt);
  }

  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
    BlockRealize stmt = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(realize));
    // Append block predicate to avoid recomputing elements.
    if (rewrite_block_predicate_) {
      rewrite_block_predicate_ = false;
      PrimExpr condition = stmt->predicate;
      for (size_t i = 0; i < info_->axis_iter_vars.size(); ++i) {
        auto iter_var = info_->axis_iter_vars[i];
        if (iter_var && info_->axis_overlaps[i] > 0) {
          Var var = iter_var.value();
          const Map<Var, arith::IntSet> dmap = {std::make_pair(var, arith::IntSet::Interval(0, 0))};
          auto iter_value = realize->iter_values[i];
          arith::Analyzer analyzer;
          auto term_2 = analyzer.int_set(iter_value, dmap).min();
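          // Keep an element's computation only on the first iteration of the rolling axis
          // (var < 1), or when the element lies beyond the overlapping prefix already
          // produced by the previous iteration (term_2, the binding evaluated with `var`
          // fixed to 0, is at least the overlap).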
          condition = analyzer.Simplify(
              And(condition, Or(LT(var, 1), GE(term_2, info_->axis_overlaps[i]))));
        }
      }
      BlockRealizeNode* n = stmt.CopyOnWrite();
      n->predicate = condition;
    }
    return std::move(stmt);
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    BufferStore stmt = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    if (stmt->buffer.same_as(info_->old_buffer)) {
      BufferStoreNode* n = stmt.CopyOnWrite();
      RewriteBufferAccess(&n->buffer, &n->indices);
      // Need to add a predicate to the current block to avoid recomputing elements.
      rewrite_block_predicate_ = true;
    }
    return std::move(stmt);
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    BufferLoad stmt = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    if (stmt->buffer.same_as(info_->old_buffer)) {
      BufferLoadNode* n = stmt.CopyOnWrite();
      RewriteBufferAccess(&n->buffer, &n->indices);
    }
    return std::move(stmt);
  }

 private:
  const StmtSRef& scope_sref_;
  RollingBufferInfo* info_;
  bool rewrite_block_predicate_ = false;
};

}  // namespace

void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index) {
  /*!
   * Check
   * - The block is not an output block.
   * - The block has only RAW dependencies.
   * - The block is tiled and there is access overlap between adjacent tiles.
   * Mutate
   * - Select the outermost rollable axis that appears in the block's loop nest
   *   as the 'rolling axis', and trim the target buffer along that axis.
   * - Use modulo arithmetic to rewrite the target buffer's store and load indices
   *   so that the buffer is circular along the rolling dimension.
   * - Append a block predicate to avoid recomputing overlapping elements.
   */
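  // Example usage from the Python schedule API (illustrative; the block name is hypothetical):
  //   sch = tir.Schedule(mod)
  //   blk = sch.get_block("cache")
  //   sch.rolling_buffer(blk, write_buffer_index=0)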
  Map<Var, arith::IntSet> dom_map;
  const BlockRealize& realize = GetBlockRealize(self, block_sref);
  const Block& block = realize->block;

  // Step 1. Check the index, and get the target buffer region and the parent scope.
  const BufferRegion& buffer_region =
      GetNthAccessBufferRegion(self, block, write_buffer_index, BufferIndexType::kWrite);
  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
  // Step 2. Check that the target block is not an output block and has only RAW dependencies.
  CheckNotOutputBlock(self, block_sref, scope_root_sref);
  RollingBufferDependencyError::Check(self, block_sref, scope_root_sref);

  // Step 3. Find the LCA of the access locations of the target buffer and relax the buffer region.
  Array<StmtSRef> loop_srefs = GetLoops(block_sref);
  Array<StmtSRef> consumers_sref = GetConsumers(self, block_sref);
  consumers_sref.push_back(block_sref);
  StmtSRef lca = GetSRefLowestCommonAncestor(consumers_sref);
  if (!lca->StmtAs<ForNode>()) {
    throw RollingBufferInsertionError(self->mod, buffer_region->buffer, block);
  }

  for (auto it = loop_srefs.rbegin(); it != loop_srefs.rend(); ++it) {
    auto stmt = *it;
    // Stop at the LCA of all the rolling buffer access points.
    if (stmt == lca) {
      break;
    }
    For cur_loop = GetRef<For>(stmt->StmtAs<ForNode>());
    Range range = Range::FromMinExtent(cur_loop->min, cur_loop->extent);
    dom_map.Set(cur_loop->loop_var, arith::IntSet::FromRange(range));
  }
  BufferRegion relaxed_region = GetRelaxedBufferRegion(realize, buffer_region, dom_map);

  // Step 4. Find a valid rolling axis and collect bound overlaps on the target buffer.
  RollingBufferInfo info = RollingBufferInfoCollector::CheckAndGetRollingBufferInfo(
      self->mod, block_sref, relaxed_region);
  // Step 5. Mutate the IR to apply the rolling access pattern.
  Stmt new_scope_root = RollingBufferRewriter::Rewrite(scope_root_sref, &info);

  // Step 6. Update the schedule state.
  self->Replace(scope_root_sref, new_scope_root, info.block_reuse);
  // Step 7. Regenerate block info from the root block, because `region_cover` for the target block
  // and `stage_pipeline` for the root block are no longer satisfied after rolling buffer injection.
  self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, self->stmt2ref.at(new_scope_root.get())));
}

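/*!
 * \brief Instruction traits of the `RollingBuffer` schedule primitive, used to record the
 *        instruction in a schedule trace and re-apply or print it as `sch.rolling_buffer(...)`.
 */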
struct RollingBufferTraits : public UnpackedInstTraits<RollingBufferTraits> {
  static constexpr const char* kName = "RollingBuffer";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 1;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer write_buffer_index) {
    return sch->RollingBuffer(block, write_buffer_index.IntValue());
  }

  static String UnpackedAsPython(Array<String> outputs, String block, Integer write_buffer_index) {
    PythonAPICall py("rolling_buffer");
    py.Input("block", block);
    py.Input("write_buffer_index", write_buffer_index);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(RollingBufferTraits);
}  // namespace tir
}  // namespace tvm