1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief Inject double buffering optimization for data fetch. |
22 | * \file inject_double_buffer.cc |
23 | */ |
24 | #include <tvm/runtime/registry.h> |
25 | #include <tvm/tir/op.h> |
26 | #include <tvm/tir/stmt_functor.h> |
27 | #include <tvm/tir/transform.h> |
28 | |
29 | #include "ir_utils.h" |
30 | |
31 | namespace tvm { |
32 | namespace tir { |
33 | |
34 | struct InjectDoubleBufferConfigNode : public tvm::AttrsNode<InjectDoubleBufferConfigNode> { |
35 | int split_loop; |
36 | |
37 | TVM_DECLARE_ATTRS(InjectDoubleBufferConfigNode, "tir.transform.InjectDoubleBufferConfig" ) { |
38 | TVM_ATTR_FIELD(split_loop).describe("Split loop factors" ).set_default(1); |
39 | } |
40 | }; |
41 | |
42 | class InjectDoubleBufferConfig : public Attrs { |
43 | public: |
44 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(InjectDoubleBufferConfig, Attrs, |
45 | InjectDoubleBufferConfigNode); |
46 | }; |
47 | |
48 | TVM_REGISTER_NODE_TYPE(InjectDoubleBufferConfigNode); |
49 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.InjectDoubleBuffer" , InjectDoubleBufferConfig); |
50 | |
51 | // Detect double buffer variables. |
52 | class DoubleBufferDetector : public StmtExprVisitor { |
53 | public: |
54 | void VisitStmt_(const AttrStmtNode* op) final { |
55 | if (op->attr_key == attr::double_buffer_scope) { |
56 | touched_.insert(op->node.as<VarNode>()); |
57 | StmtExprVisitor::VisitStmt_(op); |
58 | } else { |
59 | StmtExprVisitor::VisitStmt_(op); |
60 | } |
61 | } |
62 | |
63 | void VisitExpr_(const VarNode* op) final { |
64 | if (touched_.count(op)) { |
65 | touched_.erase(op); |
66 | } |
67 | } |
68 | // The set of touched variable. |
69 | std::unordered_set<const VarNode*> touched_; |
70 | }; |
71 | |
72 | class StripDoubleBufferWrite : public StmtMutator { |
73 | public: |
74 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
75 | if (op->attr_key == attr::double_buffer_write) { |
76 | return VisitStmt(op->body); |
77 | } else { |
78 | return StmtMutator::VisitStmt_(op); |
79 | } |
80 | } |
81 | }; |
82 | |
83 | class DoubleBufferInjector : public StmtExprMutator { |
84 | public: |
85 | explicit DoubleBufferInjector(int split_loop) : split_loop_(split_loop) {} |
86 | |
87 | Stmt Inject(Stmt stmt) { |
88 | DoubleBufferDetector detector; |
89 | detector(stmt); |
90 | if (detector.touched_.empty()) return stmt; |
91 | for (const VarNode* v : detector.touched_) { |
92 | dbuffer_info_[v] = StorageEntry(); |
93 | } |
94 | return ConvertSSA(operator()(std::move(stmt))); |
95 | } |
96 | |
97 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
98 | if (op->attr_key == attr::double_buffer_scope) { |
99 | return MakeProducer(op); |
100 | } else { |
101 | return StmtExprMutator::VisitStmt_(op); |
102 | } |
103 | } |
104 | |
105 | Stmt VisitStmt_(const AllocateNode* op) final { |
106 | const VarNode* buf = op->buffer_var.as<VarNode>(); |
107 | auto it = dbuffer_info_.find(buf); |
108 | if (it != dbuffer_info_.end()) { |
109 | it->second.scope = GetPtrStorageScope(op->buffer_var); |
110 | |
111 | ICHECK_EQ(op->extents.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " |
112 | << "Has StorageFlatten (TE-based schedules) or " |
113 | << "FlattenBuffer (TIR-based schedules) been run?" ; |
114 | it->second.stride = op->extents[0]; |
115 | Stmt stmt = StmtExprMutator::VisitStmt_(op); |
116 | op = stmt.as<AllocateNode>(); |
117 | |
118 | Array<PrimExpr> new_extents = {op->extents[0] * make_const(op->extents[0].dtype(), 2)}; |
119 | ICHECK(it->second.loop != nullptr); |
120 | auto& alloc_nest = loop_allocs_[it->second.loop]; |
121 | alloc_nest.emplace_back( |
122 | Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0))); |
123 | return op->body; |
124 | } else { |
125 | return StmtExprMutator::VisitStmt_(op); |
126 | } |
127 | } |
128 | |
129 | Stmt VisitStmt_(const ForNode* op) final { |
130 | loop_nest_.push_back(op); |
131 | Stmt stmt = StmtExprMutator::VisitStmt_(op); |
132 | auto it = loop_pre_.find(op); |
133 | if (it != loop_pre_.end()) { |
134 | const ForNode* old_loop = stmt.as<ForNode>(); |
135 | if (split_loop_ != 0) { |
136 | // Explicitly unroll the loop |
137 | ICHECK(split_loop_ % 2 == 0 || split_loop_ == 1) |
138 | << "It is better to split with multiple of 2" ; |
139 | ICHECK(is_zero(old_loop->min)); |
140 | PrimExpr zero = old_loop->min; |
141 | PrimExpr new_ext = old_loop->extent - make_const(old_loop->loop_var.dtype(), 1); |
142 | PrimExpr factor = make_const(new_ext.dtype(), split_loop_); |
143 | PrimExpr outer_ext = new_ext / factor; |
144 | PrimExpr tail_base = outer_ext * factor; |
145 | Var outer_var(old_loop->loop_var->name_hint + ".outer" , old_loop->loop_var.dtype()); |
146 | std::unordered_map<const VarNode*, PrimExpr> vmap; |
147 | std::vector<Stmt> loop_seq; |
148 | for (int32_t i = 0; i < split_loop_; ++i) { |
149 | vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i); |
150 | loop_seq.emplace_back(Substitute(old_loop->body, vmap)); |
151 | } |
152 | Stmt loop = For(outer_var, zero, outer_ext, old_loop->kind, SeqStmt::Flatten(loop_seq)); |
153 | // tail |
154 | std::vector<Stmt> tail_seq; |
155 | Stmt tail_body = StripDoubleBufferWrite()(old_loop->body); |
156 | for (int32_t i = 0; i < split_loop_; ++i) { |
157 | PrimExpr idx = tail_base + make_const(tail_base.dtype(), i); |
158 | vmap[old_loop->loop_var.get()] = idx; |
159 | tail_seq.emplace_back(IfThenElse(idx < old_loop->extent, Substitute(tail_body, vmap))); |
160 | } |
161 | stmt = SeqStmt::Flatten(loop, tail_seq); |
162 | } |
163 | stmt = SeqStmt::Flatten(it->second, stmt); |
164 | } |
165 | it = loop_allocs_.find(op); |
166 | if (it != loop_allocs_.end()) { |
167 | stmt = MergeNest(it->second, stmt); |
168 | } |
169 | loop_nest_.pop_back(); |
170 | return stmt; |
171 | } |
172 | |
173 | PrimExpr VisitExpr_(const LoadNode* op) final { |
174 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
175 | } |
176 | |
177 | Stmt VisitStmt_(const StoreNode* op) final { |
178 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
179 | } |
180 | |
181 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
182 | auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
183 | |
184 | auto it = dbuffer_info_.find(node->buffer->data.get()); |
185 | if (it != dbuffer_info_.end()) { |
186 | const StorageEntry& e = it->second; |
187 | ICHECK(in_double_buffer_scope_); |
188 | ICHECK(e.switch_write_var.defined()); |
189 | |
190 | ICHECK_EQ(node->indices.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " |
191 | << "Has StorageFlatten (TE-based schedules) or " |
192 | << "FlattenBuffer (TIR-based schedules) been run?" ; |
193 | |
194 | auto writer = node.CopyOnWrite(); |
195 | writer->buffer = GetRemappedBuffer(node->buffer, e.stride); |
196 | writer->indices = {e.switch_write_var * e.stride + node->indices[0]}; |
197 | } |
198 | |
199 | return std::move(node); |
200 | } |
201 | |
202 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
203 | auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
204 | |
205 | auto it = dbuffer_info_.find(node->buffer->data.get()); |
206 | if (it != dbuffer_info_.end()) { |
207 | const StorageEntry& e = it->second; |
208 | ICHECK(e.switch_read_var.defined()); |
209 | |
210 | ICHECK_EQ(node->indices.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " |
211 | << "Has StorageFlatten (TE-based schedules) or " |
212 | << "FlattenBuffer (TIR-based schedules) been run?" ; |
213 | |
214 | auto writer = node.CopyOnWrite(); |
215 | writer->buffer = GetRemappedBuffer(node->buffer, e.stride); |
216 | writer->indices = {e.switch_read_var * e.stride + node->indices[0]}; |
217 | } |
218 | |
219 | return std::move(node); |
220 | } |
221 | |
222 | Buffer GetRemappedBuffer(Buffer buf, PrimExpr stride) { |
223 | auto key = buf.get(); |
224 | auto it = buf_remap_.find(key); |
225 | if (it != buf_remap_.end()) { |
226 | return it->second; |
227 | } |
228 | |
229 | ICHECK(stride.defined()); |
230 | // TODO(Lunderberg): Move this pass to before |
231 | // StorageFlatten/FlattenBuffer. That will simplify the |
232 | // implementation, to be the insertion of a new dimension for the |
233 | // buffer, rather than adjusting the other indices. |
234 | ICHECK_EQ(buf->shape.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " |
235 | << "Has StorageFlatten (TE-based schedules) or " |
236 | << "FlattenBuffer (TIR-based schedules) been run?" ; |
237 | auto writer = buf.CopyOnWrite(); |
238 | writer->shape = {buf->shape[0] * stride}; |
239 | |
240 | buf_remap_[key] = buf; |
241 | return buf; |
242 | } |
243 | |
244 | PrimExpr VisitExpr_(const VarNode* op) final { |
245 | ICHECK(!dbuffer_info_.count(op)); |
246 | return GetRef<PrimExpr>(op); |
247 | } |
248 | |
249 | private: |
250 | Stmt MakeProducer(const AttrStmtNode* op) { |
251 | const Var buffer = Downcast<Var>(op->node); |
252 | ICHECK_NE(loop_nest_.size(), 0U) << "Double buffer scope must be inside a loop" ; |
253 | auto it = dbuffer_info_.find(buffer.get()); |
254 | if (it == dbuffer_info_.end()) { |
255 | LOG(WARNING) << "Skip double buffer scope " << op->node; |
256 | return this->VisitStmt(op->body); |
257 | } |
258 | StorageEntry& e = it->second; |
259 | e.loop = loop_nest_.back(); |
260 | PrimExpr zero = make_const(e.loop->loop_var.dtype(), 0); |
261 | PrimExpr one = make_const(e.loop->loop_var.dtype(), 1); |
262 | PrimExpr two = make_const(e.loop->loop_var.dtype(), 2); |
263 | PrimExpr loop_shift = e.loop->loop_var + one; |
264 | e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db" , e.loop->loop_var.dtype()); |
265 | e.switch_read_var = indexmod(e.loop->loop_var, two); |
266 | in_double_buffer_scope_ = true; |
267 | Stmt body = this->VisitStmt(op->body); |
268 | in_double_buffer_scope_ = false; |
269 | std::unordered_map<const VarNode*, PrimExpr> vmap; |
270 | vmap[e.switch_write_var.get()] = zero; |
271 | vmap[e.loop->loop_var.get()] = zero; |
272 | loop_pre_[e.loop].emplace_back(Substitute(body, vmap)); |
273 | vmap[e.loop->loop_var.get()] = loop_shift; |
274 | vmap[e.switch_write_var.get()] = indexmod(loop_shift, two); |
275 | body = Substitute(body, vmap); |
276 | body = AttrStmt(buffer, attr::double_buffer_write, 1, body); |
277 | body = IfThenElse(loop_shift < e.loop->extent, body); |
278 | return body; |
279 | } |
280 | // Storage entry for those who need double buffering. |
281 | struct StorageEntry { |
282 | // The size of the buffer |
283 | PrimExpr stride; |
284 | // The loop we need |
285 | const ForNode* loop{nullptr}; |
286 | // The switch variable. |
287 | Var switch_write_var; |
288 | // The switch variable for reading. |
289 | PrimExpr switch_read_var; |
290 | // The storage scope. |
291 | std::string scope; |
292 | }; |
293 | // Whether split loop |
294 | int32_t split_loop_; |
295 | // Whether we are inside double buffer scope. |
296 | bool in_double_buffer_scope_{false}; |
297 | // The current loop next |
298 | std::vector<const ForNode*> loop_nest_; |
299 | // The allocs to be appended before the loop |
300 | std::unordered_map<const ForNode*, std::vector<Stmt>> loop_allocs_; |
301 | // The stmt to be appended before the loop |
302 | std::unordered_map<const ForNode*, std::vector<Stmt>> loop_pre_; |
303 | // The allocation size of the buffer |
304 | std::unordered_map<const VarNode*, StorageEntry> dbuffer_info_; |
305 | // The updated Buffer objects |
306 | std::unordered_map<const BufferNode*, Buffer> buf_remap_; |
307 | }; |
308 | |
309 | namespace transform { |
310 | |
311 | Pass InjectDoubleBuffer() { |
312 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
313 | auto* n = f.CopyOnWrite(); |
314 | auto cfg = ctx->GetConfig<InjectDoubleBufferConfig>("tir.InjectDoubleBuffer" ); |
315 | if (!cfg.defined()) { |
316 | cfg = AttrsWithDefaultValues<InjectDoubleBufferConfig>(); |
317 | } |
318 | n->body = DoubleBufferInjector(cfg.value()->split_loop).Inject(std::move(n->body)); |
319 | return f; |
320 | }; |
321 | return CreatePrimFuncPass(pass_func, 0, "tir.InjectDoubleBuffer" , {}); |
322 | } |
323 | |
324 | TVM_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer" ).set_body_typed(InjectDoubleBuffer); |
325 | |
326 | } // namespace transform |
327 | |
328 | } // namespace tir |
329 | } // namespace tvm |
330 | |