1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file inject_rolling_buffer.cc
22 * \brief Inject rolling buffer statements.
23
24 Rolling buffers are buffers where one of the dimensions has been made into
25 a circular buffer. Two optimizations are implemented in order to accomplish
26 this: sliding window and storage folding. In particular, the sliding window
27 optimization is applied to the entire buffer (to avoid recomputing elements)
28 and storage folding is then applied to just the rolling dimension.
29
 Rolling buffers must be inside a loop with only part of the buffer used per
 iteration. The buffer axis whose bound depends on the outermost such loop
 variable is the one that will be rolled over.
32
33 For more information, see the RFC:
34 https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836
35 */
36#include <tvm/arith/analyzer.h>
37#include <tvm/runtime/registry.h>
38#include <tvm/tir/stmt_functor.h>
39#include <tvm/tir/transform.h>
40
41#include "ir_utils.h"
42
43namespace tvm {
44namespace tir {
45
46using arith::IntSet;
47
48struct RollingBufferInfo {
49 int rolling_axis;
50 int rolling_extent;
51 std::vector<int> axis_overlaps;
52 std::vector<Optional<Var>> axis_iter_vars;
53};
54
55class RollingBufferInjector : public StmtExprMutator {
56 std::vector<For> for_loops{};
57 std::set<Buffer> rolling_buffers{};
58 std::map<Buffer, BufferRealize> buffer_to_buffer_realize{};
59 std::map<Buffer, std::vector<AttrStmt>> buffer_to_attrs{};
60 std::map<Buffer, RollingBufferInfo> rolling_buffer_to_info{};
61 // The actual key type is Var, ObjectRef has been used because
62 // of the ambiguous overload for ‘operator<’
63 std::map<ObjectRef, std::vector<BufferRealize>> hoist_buffer_to_for{};
64
65 public:
66 RollingBufferInjector() {}
67
68 Stmt Inject(Stmt stmt) { return ConvertSSA(operator()(std::move(stmt))); }
69
70 Stmt VisitStmt_(const ForNode* op) final {
71 // Manage the stack of iter_vars
72 for_loops.push_back(GetRef<For>(op));
73
74 auto stmt{StmtExprMutator::VisitStmt_(op)};
75 op = stmt.as<ForNode>();
76
77 // Manage the stack of iter_vars
78 for_loops.pop_back();
79
80 auto it{hoist_buffer_to_for.find(op->loop_var)};
81 if (it != hoist_buffer_to_for.end()) {
82 // If the loop corresponds to an iter_var that needs a BufferRealize
83 // hoisting to its scope, perform the hoisting
84 Stmt body{GetRef<For>(op)};
85 for (auto realise : it->second) {
86 auto attrs{buffer_to_attrs[realise->buffer]};
87 Stmt new_realize{BufferRealize(realise->buffer, realise->bounds, realise->condition, body,
88 realise->span)};
89 // The attributes attached to the BufferRealize need hoisting too
90 for (auto attr : attrs) {
91 if (attr->attr_key == attr::rolling_buffer_scope) {
92 continue;
93 }
94 new_realize = AttrStmt(attr->node, attr->attr_key, attr->value, new_realize, attr->span);
95 }
96 body = new_realize;
97 }
98 return body;
99 } else {
100 return stmt;
101 }
102 }
103
104 Stmt VisitStmt_(const AttrStmtNode* op) final {
105 if (auto b = op->node.as<BufferNode>()) {
106 auto buffer = GetRef<Buffer>(b);
107 // Keep a dictionary associating attribute statements with the buffers
108 // they reference. We'll need this if the buffer gets hoisted and we
109 // need to hoist all of its attributes at the same time.
110 buffer_to_attrs[buffer].push_back(GetRef<AttrStmt>(op));
111
112 if (op->attr_key == attr::rolling_buffer_scope && Downcast<IntImm>(op->value)->value) {
113 // If the attribute is indicating that a buffer should be a rolling
114 // buffer, then update the rolling_buffers set to include the buffer
115 rolling_buffers.insert(buffer);
116
117 auto it{buffer_to_buffer_realize.find(buffer)};
118 ICHECK(it != buffer_to_buffer_realize.end())
119 << "Rolling buffer injection failed: no BufferRealize found";
120 BufferRealize buffer_realize = it->second;
121
122 // If a BufferRealize has been identified as needing to be made into
123 // a rolling buffer, begin the analysis.
124 std::vector<Optional<Var>> bound_iter_vars{};
125 std::vector<int> bound_overlaps{};
126 // We use the bound information of the BufferRealize to calculate
127 // how we can legally roll
128 auto stride{0};
129 auto divisor{1};
130 Optional<Var> iter_var{};
131 for (auto bound : buffer_realize->bounds) {
132 divisor = 1;
133 if (auto floor_div = bound->min.as<FloorDivNode>()) {
134 // Handle the case of fractional strides
135 // They take this form: floordiv(hh.outer, 2)
136 // Strip the floordiv and keep track of the divisor
137 divisor = Downcast<IntImm>(floor_div->b)->value;
138 bound = Range::FromMinExtent(floor_div->a, bound->extent, bound->span);
139 }
140 if (bound->min.as<IntImmNode>()) {
141 // If the bound is an int, we can't roll over it
142 iter_var = nullptr;
143 } else if (auto var = bound->min.as<VarNode>()) {
144 // If the bound is just a Var, that implies the stride is 1
145 iter_var = GetRef<Var>(var);
146 stride = 1;
147 } else {
148 // Otherwise, it's the iter var multiplied by the stride
149 // If not we're in unknown behaviour, so assert
150 auto mul = bound->min.as<MulNode>();
151 ICHECK(mul) << "Rolling buffer injection failed: the buffer striding is unsupported";
152 auto a = mul->a.as<VarNode>();
153 ICHECK(a) << "Rolling buffer injection failed: the buffer striding is unsupported";
154 auto b = mul->b.as<IntImmNode>();
155 ICHECK(b) << "Rolling buffer injection failed: the buffer striding is unsupported";
156 iter_var = GetRef<Var>(a);
157 stride = b->value;
158 }
159 stride = std::ceil(static_cast<float>(stride) / divisor);
160 bound_iter_vars.push_back(iter_var);
161 if (iter_var) {
162 bound_overlaps.push_back(Downcast<IntImm>(bound->extent)->value - stride);
163 } else {
164 bound_overlaps.push_back(0);
165 }
166 }
167 // Pick the outermost iter_var that's mentioned in the bounds
168 // to be the rolling axis
169 Optional<Var> roll_iter_var{};
170 int roll_axis{1};
171 for (auto loop : for_loops) {
172 auto loop_var{loop->loop_var};
173 iter_var = loop_var;
174
175 auto it{std::find_if(
176 bound_iter_vars.begin(), bound_iter_vars.end(),
177 [&](Optional<Var> var) { return var && (var.get() == loop_var.get()); })};
178
179 if (it != bound_iter_vars.end()) {
180 auto i{std::distance(bound_iter_vars.begin(), it)};
181 roll_iter_var = loop_var;
182 roll_axis = i;
183 break;
184 }
185 }
186 // We must have found an axis to roll over
187 ICHECK(roll_iter_var) << "Rolling buffer injection failed: no rolling axis found";
188 ICHECK(roll_axis != -1) << "Rolling buffer injection failed: no rolling axis found";
189
190 RollingBufferInfo rolling_buffer_info = {
191 roll_axis,
192 static_cast<int>(Downcast<IntImm>(buffer_realize->bounds[roll_axis]->extent)->value),
193 bound_overlaps,
194 bound_iter_vars,
195 };
196 rolling_buffer_to_info[buffer] = rolling_buffer_info;
197 Array<Range> new_bounds{};
198 auto shape{buffer->shape};
199 for (size_t i{0}; i < shape.size(); ++i) {
200 auto extent{shape[i]};
201 if (static_cast<int>(i) == rolling_buffer_info.rolling_axis) {
202 new_bounds.push_back(Range(0, rolling_buffer_info.rolling_extent));
203 } else {
204 new_bounds.push_back(Range(0, extent));
205 }
206 }
207 BufferRealize new_realize{BufferRealize(buffer, new_bounds, buffer_realize->condition,
208 buffer_realize->body, buffer_realize->span)};
209 hoist_buffer_to_for[iter_var.value()].push_back(new_realize);
210 }
211 }
212
213 auto stmt{StmtExprMutator::VisitStmt_(op)};
214 op = stmt.as<AttrStmtNode>();
215
216 if (rolling_buffers.count(GetRef<Buffer>(op->node.as<BufferNode>()))) {
217 // Remove the attribute statements attached to rolling buffers
218 // because they will have been hoisted to the relevant rolling
219 // scope
220 return op->body;
221 } else {
222 return stmt;
223 }
224 }
225
226 Stmt VisitStmt_(const BufferRealizeNode* op) final {
227 buffer_to_buffer_realize.insert({op->buffer, GetRef<BufferRealize>(op)});
228
229 auto stmt{StmtExprMutator::VisitStmt_(op)};
230 op = stmt.as<BufferRealizeNode>();
231
232 if (rolling_buffers.count(op->buffer)) {
233 // Remove the original BufferRealize for rolling buffers
234 // because they will have been hoisted to the relevant rolling
235 // scope
236 return op->body;
237 } else {
238 return stmt;
239 }
240 }
241
242 Stmt VisitStmt_(const BufferStoreNode* op) final {
243 auto stmt{StmtExprMutator::VisitStmt_(op)};
244 op = stmt.as<BufferStoreNode>();
245
246 auto it{rolling_buffer_to_info.find(op->buffer)};
247 if (it != rolling_buffer_to_info.end()) {
248 auto rolling_buffer_info{it->second};
249 std::vector<PrimExpr> indices{};
250 // First modify the access indices to use modulo arithmetic
251 // for the rolling axis
252 for (size_t i{0}; i < op->indices.size(); ++i) {
253 auto index{op->indices[i]};
254 if (static_cast<int>(i) == rolling_buffer_info.rolling_axis) {
255 indices.push_back(FloorMod(index, rolling_buffer_info.rolling_extent));
256 } else {
257 indices.push_back(index);
258 }
259 }
260 Stmt buffer_store = BufferStore(op->buffer, op->value, indices, op->span);
261 // Then wrap the BufferStores in some Ifs to avoid recomputing elements
262 for (size_t i{0}; i < rolling_buffer_info.axis_iter_vars.size(); ++i) {
263 auto iter_var{rolling_buffer_info.axis_iter_vars[i]};
264 if (iter_var && rolling_buffer_info.axis_overlaps[i] > 0) {
265 Var var{iter_var.value()};
266 const Map<Var, IntSet> dmap{std::make_pair(var, IntSet::Interval(0, 0))};
267 auto term_2{arith::Analyzer{}.int_set(op->indices[i], dmap).min()};
268 auto condition = Or(LT(var, 1), GE(term_2, rolling_buffer_info.axis_overlaps[i]));
269 buffer_store = IfThenElse(likely(condition), buffer_store);
270 }
271 }
272 return buffer_store;
273 } else {
274 return stmt;
275 }
276 }
277
278 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
279 auto expr{StmtExprMutator::VisitExpr_(op)};
280 op = expr.as<BufferLoadNode>();
281
282 auto it{rolling_buffer_to_info.find(op->buffer)};
283 if (it != rolling_buffer_to_info.end()) {
284 auto rolling_buffer_info{it->second};
285 std::vector<PrimExpr> indices{};
286 // Modify the access indices to use modulo arithmetic
287 // for the rolling axis
288 for (size_t i{0}; i < op->indices.size(); ++i) {
289 auto index{op->indices[i]};
290 if (static_cast<int>(i) == rolling_buffer_info.rolling_axis) {
291 indices.push_back(FloorMod(index, rolling_buffer_info.rolling_extent));
292 } else {
293 indices.push_back(index);
294 }
295 }
296 return BufferLoad(op->buffer, indices, op->span);
297 } else {
298 return expr;
299 }
300 }
301}; // namespace tir
302
303namespace transform {
304
305Pass InjectRollingBuffer() {
306 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
307 auto* n = f.CopyOnWrite();
308 n->body = RollingBufferInjector().Inject(std::move(n->body));
309 return f;
310 };
311 return CreatePrimFuncPass(pass_func, 0, "tir.InjectRollingBuffer", {});
312}
313
314TVM_REGISTER_GLOBAL("tir.transform.InjectRollingBuffer").set_body_typed(InjectRollingBuffer);
315
316} // namespace transform
317
318} // namespace tir
319} // namespace tvm
320