1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file inject_rolling_buffer.cc |
22 | * \brief Inject rolling buffer statements. |
23 | |
24 | Rolling buffers are buffers where one of the dimensions has been made into |
25 | a circular buffer. Two optimizations are implemented in order to accomplish |
26 | this: sliding window and storage folding. In particular, the sliding window |
27 | optimization is applied to the entire buffer (to avoid recomputing elements) |
28 | and storage folding is then applied to just the rolling dimension. |
29 | |
30 | Rolling buffers must be inside a loop with only part of the buffer used per |
31 | iteration. The outermost axis will be rolled over. |
32 | |
33 | For more information, see the RFC: |
34 | https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836 |
35 | */ |
36 | #include <tvm/arith/analyzer.h> |
37 | #include <tvm/runtime/registry.h> |
38 | #include <tvm/tir/stmt_functor.h> |
39 | #include <tvm/tir/transform.h> |
40 | |
41 | #include "ir_utils.h" |
42 | |
43 | namespace tvm { |
44 | namespace tir { |
45 | |
46 | using arith::IntSet; |
47 | |
48 | struct RollingBufferInfo { |
49 | int rolling_axis; |
50 | int rolling_extent; |
51 | std::vector<int> axis_overlaps; |
52 | std::vector<Optional<Var>> axis_iter_vars; |
53 | }; |
54 | |
55 | class RollingBufferInjector : public StmtExprMutator { |
56 | std::vector<For> for_loops{}; |
57 | std::set<Buffer> rolling_buffers{}; |
58 | std::map<Buffer, BufferRealize> buffer_to_buffer_realize{}; |
59 | std::map<Buffer, std::vector<AttrStmt>> buffer_to_attrs{}; |
60 | std::map<Buffer, RollingBufferInfo> rolling_buffer_to_info{}; |
61 | // The actual key type is Var, ObjectRef has been used because |
62 | // of the ambiguous overload for ‘operator<’ |
63 | std::map<ObjectRef, std::vector<BufferRealize>> hoist_buffer_to_for{}; |
64 | |
65 | public: |
66 | RollingBufferInjector() {} |
67 | |
68 | Stmt Inject(Stmt stmt) { return ConvertSSA(operator()(std::move(stmt))); } |
69 | |
70 | Stmt VisitStmt_(const ForNode* op) final { |
71 | // Manage the stack of iter_vars |
72 | for_loops.push_back(GetRef<For>(op)); |
73 | |
74 | auto stmt{StmtExprMutator::VisitStmt_(op)}; |
75 | op = stmt.as<ForNode>(); |
76 | |
77 | // Manage the stack of iter_vars |
78 | for_loops.pop_back(); |
79 | |
80 | auto it{hoist_buffer_to_for.find(op->loop_var)}; |
81 | if (it != hoist_buffer_to_for.end()) { |
82 | // If the loop corresponds to an iter_var that needs a BufferRealize |
83 | // hoisting to its scope, perform the hoisting |
84 | Stmt body{GetRef<For>(op)}; |
85 | for (auto realise : it->second) { |
86 | auto attrs{buffer_to_attrs[realise->buffer]}; |
87 | Stmt new_realize{BufferRealize(realise->buffer, realise->bounds, realise->condition, body, |
88 | realise->span)}; |
89 | // The attributes attached to the BufferRealize need hoisting too |
90 | for (auto attr : attrs) { |
91 | if (attr->attr_key == attr::rolling_buffer_scope) { |
92 | continue; |
93 | } |
94 | new_realize = AttrStmt(attr->node, attr->attr_key, attr->value, new_realize, attr->span); |
95 | } |
96 | body = new_realize; |
97 | } |
98 | return body; |
99 | } else { |
100 | return stmt; |
101 | } |
102 | } |
103 | |
104 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
105 | if (auto b = op->node.as<BufferNode>()) { |
106 | auto buffer = GetRef<Buffer>(b); |
107 | // Keep a dictionary associating attribute statements with the buffers |
108 | // they reference. We'll need this if the buffer gets hoisted and we |
109 | // need to hoist all of its attributes at the same time. |
110 | buffer_to_attrs[buffer].push_back(GetRef<AttrStmt>(op)); |
111 | |
112 | if (op->attr_key == attr::rolling_buffer_scope && Downcast<IntImm>(op->value)->value) { |
113 | // If the attribute is indicating that a buffer should be a rolling |
114 | // buffer, then update the rolling_buffers set to include the buffer |
115 | rolling_buffers.insert(buffer); |
116 | |
117 | auto it{buffer_to_buffer_realize.find(buffer)}; |
118 | ICHECK(it != buffer_to_buffer_realize.end()) |
119 | << "Rolling buffer injection failed: no BufferRealize found" ; |
120 | BufferRealize buffer_realize = it->second; |
121 | |
122 | // If a BufferRealize has been identified as needing to be made into |
123 | // a rolling buffer, begin the analysis. |
124 | std::vector<Optional<Var>> bound_iter_vars{}; |
125 | std::vector<int> bound_overlaps{}; |
126 | // We use the bound information of the BufferRealize to calculate |
127 | // how we can legally roll |
128 | auto stride{0}; |
129 | auto divisor{1}; |
130 | Optional<Var> iter_var{}; |
131 | for (auto bound : buffer_realize->bounds) { |
132 | divisor = 1; |
133 | if (auto floor_div = bound->min.as<FloorDivNode>()) { |
134 | // Handle the case of fractional strides |
135 | // They take this form: floordiv(hh.outer, 2) |
136 | // Strip the floordiv and keep track of the divisor |
137 | divisor = Downcast<IntImm>(floor_div->b)->value; |
138 | bound = Range::FromMinExtent(floor_div->a, bound->extent, bound->span); |
139 | } |
140 | if (bound->min.as<IntImmNode>()) { |
141 | // If the bound is an int, we can't roll over it |
142 | iter_var = nullptr; |
143 | } else if (auto var = bound->min.as<VarNode>()) { |
144 | // If the bound is just a Var, that implies the stride is 1 |
145 | iter_var = GetRef<Var>(var); |
146 | stride = 1; |
147 | } else { |
148 | // Otherwise, it's the iter var multiplied by the stride |
149 | // If not we're in unknown behaviour, so assert |
150 | auto mul = bound->min.as<MulNode>(); |
151 | ICHECK(mul) << "Rolling buffer injection failed: the buffer striding is unsupported" ; |
152 | auto a = mul->a.as<VarNode>(); |
153 | ICHECK(a) << "Rolling buffer injection failed: the buffer striding is unsupported" ; |
154 | auto b = mul->b.as<IntImmNode>(); |
155 | ICHECK(b) << "Rolling buffer injection failed: the buffer striding is unsupported" ; |
156 | iter_var = GetRef<Var>(a); |
157 | stride = b->value; |
158 | } |
159 | stride = std::ceil(static_cast<float>(stride) / divisor); |
160 | bound_iter_vars.push_back(iter_var); |
161 | if (iter_var) { |
162 | bound_overlaps.push_back(Downcast<IntImm>(bound->extent)->value - stride); |
163 | } else { |
164 | bound_overlaps.push_back(0); |
165 | } |
166 | } |
167 | // Pick the outermost iter_var that's mentioned in the bounds |
168 | // to be the rolling axis |
169 | Optional<Var> roll_iter_var{}; |
170 | int roll_axis{1}; |
171 | for (auto loop : for_loops) { |
172 | auto loop_var{loop->loop_var}; |
173 | iter_var = loop_var; |
174 | |
175 | auto it{std::find_if( |
176 | bound_iter_vars.begin(), bound_iter_vars.end(), |
177 | [&](Optional<Var> var) { return var && (var.get() == loop_var.get()); })}; |
178 | |
179 | if (it != bound_iter_vars.end()) { |
180 | auto i{std::distance(bound_iter_vars.begin(), it)}; |
181 | roll_iter_var = loop_var; |
182 | roll_axis = i; |
183 | break; |
184 | } |
185 | } |
186 | // We must have found an axis to roll over |
187 | ICHECK(roll_iter_var) << "Rolling buffer injection failed: no rolling axis found" ; |
188 | ICHECK(roll_axis != -1) << "Rolling buffer injection failed: no rolling axis found" ; |
189 | |
190 | RollingBufferInfo rolling_buffer_info = { |
191 | roll_axis, |
192 | static_cast<int>(Downcast<IntImm>(buffer_realize->bounds[roll_axis]->extent)->value), |
193 | bound_overlaps, |
194 | bound_iter_vars, |
195 | }; |
196 | rolling_buffer_to_info[buffer] = rolling_buffer_info; |
197 | Array<Range> new_bounds{}; |
198 | auto shape{buffer->shape}; |
199 | for (size_t i{0}; i < shape.size(); ++i) { |
200 | auto extent{shape[i]}; |
201 | if (static_cast<int>(i) == rolling_buffer_info.rolling_axis) { |
202 | new_bounds.push_back(Range(0, rolling_buffer_info.rolling_extent)); |
203 | } else { |
204 | new_bounds.push_back(Range(0, extent)); |
205 | } |
206 | } |
207 | BufferRealize new_realize{BufferRealize(buffer, new_bounds, buffer_realize->condition, |
208 | buffer_realize->body, buffer_realize->span)}; |
209 | hoist_buffer_to_for[iter_var.value()].push_back(new_realize); |
210 | } |
211 | } |
212 | |
213 | auto stmt{StmtExprMutator::VisitStmt_(op)}; |
214 | op = stmt.as<AttrStmtNode>(); |
215 | |
216 | if (rolling_buffers.count(GetRef<Buffer>(op->node.as<BufferNode>()))) { |
217 | // Remove the attribute statements attached to rolling buffers |
218 | // because they will have been hoisted to the relevant rolling |
219 | // scope |
220 | return op->body; |
221 | } else { |
222 | return stmt; |
223 | } |
224 | } |
225 | |
226 | Stmt VisitStmt_(const BufferRealizeNode* op) final { |
227 | buffer_to_buffer_realize.insert({op->buffer, GetRef<BufferRealize>(op)}); |
228 | |
229 | auto stmt{StmtExprMutator::VisitStmt_(op)}; |
230 | op = stmt.as<BufferRealizeNode>(); |
231 | |
232 | if (rolling_buffers.count(op->buffer)) { |
233 | // Remove the original BufferRealize for rolling buffers |
234 | // because they will have been hoisted to the relevant rolling |
235 | // scope |
236 | return op->body; |
237 | } else { |
238 | return stmt; |
239 | } |
240 | } |
241 | |
242 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
243 | auto stmt{StmtExprMutator::VisitStmt_(op)}; |
244 | op = stmt.as<BufferStoreNode>(); |
245 | |
246 | auto it{rolling_buffer_to_info.find(op->buffer)}; |
247 | if (it != rolling_buffer_to_info.end()) { |
248 | auto rolling_buffer_info{it->second}; |
249 | std::vector<PrimExpr> indices{}; |
250 | // First modify the access indices to use modulo arithmetic |
251 | // for the rolling axis |
252 | for (size_t i{0}; i < op->indices.size(); ++i) { |
253 | auto index{op->indices[i]}; |
254 | if (static_cast<int>(i) == rolling_buffer_info.rolling_axis) { |
255 | indices.push_back(FloorMod(index, rolling_buffer_info.rolling_extent)); |
256 | } else { |
257 | indices.push_back(index); |
258 | } |
259 | } |
260 | Stmt buffer_store = BufferStore(op->buffer, op->value, indices, op->span); |
261 | // Then wrap the BufferStores in some Ifs to avoid recomputing elements |
262 | for (size_t i{0}; i < rolling_buffer_info.axis_iter_vars.size(); ++i) { |
263 | auto iter_var{rolling_buffer_info.axis_iter_vars[i]}; |
264 | if (iter_var && rolling_buffer_info.axis_overlaps[i] > 0) { |
265 | Var var{iter_var.value()}; |
266 | const Map<Var, IntSet> dmap{std::make_pair(var, IntSet::Interval(0, 0))}; |
267 | auto term_2{arith::Analyzer{}.int_set(op->indices[i], dmap).min()}; |
268 | auto condition = Or(LT(var, 1), GE(term_2, rolling_buffer_info.axis_overlaps[i])); |
269 | buffer_store = IfThenElse(likely(condition), buffer_store); |
270 | } |
271 | } |
272 | return buffer_store; |
273 | } else { |
274 | return stmt; |
275 | } |
276 | } |
277 | |
278 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
279 | auto expr{StmtExprMutator::VisitExpr_(op)}; |
280 | op = expr.as<BufferLoadNode>(); |
281 | |
282 | auto it{rolling_buffer_to_info.find(op->buffer)}; |
283 | if (it != rolling_buffer_to_info.end()) { |
284 | auto rolling_buffer_info{it->second}; |
285 | std::vector<PrimExpr> indices{}; |
286 | // Modify the access indices to use modulo arithmetic |
287 | // for the rolling axis |
288 | for (size_t i{0}; i < op->indices.size(); ++i) { |
289 | auto index{op->indices[i]}; |
290 | if (static_cast<int>(i) == rolling_buffer_info.rolling_axis) { |
291 | indices.push_back(FloorMod(index, rolling_buffer_info.rolling_extent)); |
292 | } else { |
293 | indices.push_back(index); |
294 | } |
295 | } |
296 | return BufferLoad(op->buffer, indices, op->span); |
297 | } else { |
298 | return expr; |
299 | } |
300 | } |
301 | }; // namespace tir |
302 | |
303 | namespace transform { |
304 | |
305 | Pass InjectRollingBuffer() { |
306 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
307 | auto* n = f.CopyOnWrite(); |
308 | n->body = RollingBufferInjector().Inject(std::move(n->body)); |
309 | return f; |
310 | }; |
311 | return CreatePrimFuncPass(pass_func, 0, "tir.InjectRollingBuffer" , {}); |
312 | } |
313 | |
314 | TVM_REGISTER_GLOBAL("tir.transform.InjectRollingBuffer" ).set_body_typed(InjectRollingBuffer); |
315 | |
316 | } // namespace transform |
317 | |
318 | } // namespace tir |
319 | } // namespace tvm |
320 | |