1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief Inject double buffering optimization for data fetch.
22 * \file inject_double_buffer.cc
23 */
24#include <tvm/runtime/registry.h>
25#include <tvm/tir/op.h>
26#include <tvm/tir/stmt_functor.h>
27#include <tvm/tir/transform.h>
28
29#include "ir_utils.h"
30
31namespace tvm {
32namespace tir {
33
/*! \brief Configuration attrs for the tir.InjectDoubleBuffer pass. */
struct InjectDoubleBufferConfigNode : public tvm::AttrsNode<InjectDoubleBufferConfigNode> {
  // Loop-splitting factor used when unrolling the double-buffered loop;
  // defaults to 1 (no explicit unroll), 0 disables splitting entirely.
  int split_loop;

  TVM_DECLARE_ATTRS(InjectDoubleBufferConfigNode, "tir.transform.InjectDoubleBufferConfig") {
    TVM_ATTR_FIELD(split_loop).describe("Split loop factors").set_default(1);
  }
};
41
/*! \brief Managed (non-nullable) reference to InjectDoubleBufferConfigNode. */
class InjectDoubleBufferConfig : public Attrs {
 public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(InjectDoubleBufferConfig, Attrs,
                                            InjectDoubleBufferConfigNode);
};
47
// Register the config node type and expose it as the PassContext option
// "tir.InjectDoubleBuffer".
TVM_REGISTER_NODE_TYPE(InjectDoubleBufferConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.InjectDoubleBuffer", InjectDoubleBufferConfig);
50
51// Detect double buffer variables.
52class DoubleBufferDetector : public StmtExprVisitor {
53 public:
54 void VisitStmt_(const AttrStmtNode* op) final {
55 if (op->attr_key == attr::double_buffer_scope) {
56 touched_.insert(op->node.as<VarNode>());
57 StmtExprVisitor::VisitStmt_(op);
58 } else {
59 StmtExprVisitor::VisitStmt_(op);
60 }
61 }
62
63 void VisitExpr_(const VarNode* op) final {
64 if (touched_.count(op)) {
65 touched_.erase(op);
66 }
67 }
68 // The set of touched variable.
69 std::unordered_set<const VarNode*> touched_;
70};
71
72class StripDoubleBufferWrite : public StmtMutator {
73 public:
74 Stmt VisitStmt_(const AttrStmtNode* op) final {
75 if (op->attr_key == attr::double_buffer_write) {
76 return VisitStmt(op->body);
77 } else {
78 return StmtMutator::VisitStmt_(op);
79 }
80 }
81};
82
83class DoubleBufferInjector : public StmtExprMutator {
84 public:
85 explicit DoubleBufferInjector(int split_loop) : split_loop_(split_loop) {}
86
87 Stmt Inject(Stmt stmt) {
88 DoubleBufferDetector detector;
89 detector(stmt);
90 if (detector.touched_.empty()) return stmt;
91 for (const VarNode* v : detector.touched_) {
92 dbuffer_info_[v] = StorageEntry();
93 }
94 return ConvertSSA(operator()(std::move(stmt)));
95 }
96
97 Stmt VisitStmt_(const AttrStmtNode* op) final {
98 if (op->attr_key == attr::double_buffer_scope) {
99 return MakeProducer(op);
100 } else {
101 return StmtExprMutator::VisitStmt_(op);
102 }
103 }
104
105 Stmt VisitStmt_(const AllocateNode* op) final {
106 const VarNode* buf = op->buffer_var.as<VarNode>();
107 auto it = dbuffer_info_.find(buf);
108 if (it != dbuffer_info_.end()) {
109 it->second.scope = GetPtrStorageScope(op->buffer_var);
110
111 ICHECK_EQ(op->extents.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. "
112 << "Has StorageFlatten (TE-based schedules) or "
113 << "FlattenBuffer (TIR-based schedules) been run?";
114 it->second.stride = op->extents[0];
115 Stmt stmt = StmtExprMutator::VisitStmt_(op);
116 op = stmt.as<AllocateNode>();
117
118 Array<PrimExpr> new_extents = {op->extents[0] * make_const(op->extents[0].dtype(), 2)};
119 ICHECK(it->second.loop != nullptr);
120 auto& alloc_nest = loop_allocs_[it->second.loop];
121 alloc_nest.emplace_back(
122 Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0)));
123 return op->body;
124 } else {
125 return StmtExprMutator::VisitStmt_(op);
126 }
127 }
128
129 Stmt VisitStmt_(const ForNode* op) final {
130 loop_nest_.push_back(op);
131 Stmt stmt = StmtExprMutator::VisitStmt_(op);
132 auto it = loop_pre_.find(op);
133 if (it != loop_pre_.end()) {
134 const ForNode* old_loop = stmt.as<ForNode>();
135 if (split_loop_ != 0) {
136 // Explicitly unroll the loop
137 ICHECK(split_loop_ % 2 == 0 || split_loop_ == 1)
138 << "It is better to split with multiple of 2";
139 ICHECK(is_zero(old_loop->min));
140 PrimExpr zero = old_loop->min;
141 PrimExpr new_ext = old_loop->extent - make_const(old_loop->loop_var.dtype(), 1);
142 PrimExpr factor = make_const(new_ext.dtype(), split_loop_);
143 PrimExpr outer_ext = new_ext / factor;
144 PrimExpr tail_base = outer_ext * factor;
145 Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.dtype());
146 std::unordered_map<const VarNode*, PrimExpr> vmap;
147 std::vector<Stmt> loop_seq;
148 for (int32_t i = 0; i < split_loop_; ++i) {
149 vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i);
150 loop_seq.emplace_back(Substitute(old_loop->body, vmap));
151 }
152 Stmt loop = For(outer_var, zero, outer_ext, old_loop->kind, SeqStmt::Flatten(loop_seq));
153 // tail
154 std::vector<Stmt> tail_seq;
155 Stmt tail_body = StripDoubleBufferWrite()(old_loop->body);
156 for (int32_t i = 0; i < split_loop_; ++i) {
157 PrimExpr idx = tail_base + make_const(tail_base.dtype(), i);
158 vmap[old_loop->loop_var.get()] = idx;
159 tail_seq.emplace_back(IfThenElse(idx < old_loop->extent, Substitute(tail_body, vmap)));
160 }
161 stmt = SeqStmt::Flatten(loop, tail_seq);
162 }
163 stmt = SeqStmt::Flatten(it->second, stmt);
164 }
165 it = loop_allocs_.find(op);
166 if (it != loop_allocs_.end()) {
167 stmt = MergeNest(it->second, stmt);
168 }
169 loop_nest_.pop_back();
170 return stmt;
171 }
172
173 PrimExpr VisitExpr_(const LoadNode* op) final {
174 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
175 }
176
177 Stmt VisitStmt_(const StoreNode* op) final {
178 LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
179 }
180
181 Stmt VisitStmt_(const BufferStoreNode* op) final {
182 auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
183
184 auto it = dbuffer_info_.find(node->buffer->data.get());
185 if (it != dbuffer_info_.end()) {
186 const StorageEntry& e = it->second;
187 ICHECK(in_double_buffer_scope_);
188 ICHECK(e.switch_write_var.defined());
189
190 ICHECK_EQ(node->indices.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. "
191 << "Has StorageFlatten (TE-based schedules) or "
192 << "FlattenBuffer (TIR-based schedules) been run?";
193
194 auto writer = node.CopyOnWrite();
195 writer->buffer = GetRemappedBuffer(node->buffer, e.stride);
196 writer->indices = {e.switch_write_var * e.stride + node->indices[0]};
197 }
198
199 return std::move(node);
200 }
201
202 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
203 auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
204
205 auto it = dbuffer_info_.find(node->buffer->data.get());
206 if (it != dbuffer_info_.end()) {
207 const StorageEntry& e = it->second;
208 ICHECK(e.switch_read_var.defined());
209
210 ICHECK_EQ(node->indices.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. "
211 << "Has StorageFlatten (TE-based schedules) or "
212 << "FlattenBuffer (TIR-based schedules) been run?";
213
214 auto writer = node.CopyOnWrite();
215 writer->buffer = GetRemappedBuffer(node->buffer, e.stride);
216 writer->indices = {e.switch_read_var * e.stride + node->indices[0]};
217 }
218
219 return std::move(node);
220 }
221
222 Buffer GetRemappedBuffer(Buffer buf, PrimExpr stride) {
223 auto key = buf.get();
224 auto it = buf_remap_.find(key);
225 if (it != buf_remap_.end()) {
226 return it->second;
227 }
228
229 ICHECK(stride.defined());
230 // TODO(Lunderberg): Move this pass to before
231 // StorageFlatten/FlattenBuffer. That will simplify the
232 // implementation, to be the insertion of a new dimension for the
233 // buffer, rather than adjusting the other indices.
234 ICHECK_EQ(buf->shape.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. "
235 << "Has StorageFlatten (TE-based schedules) or "
236 << "FlattenBuffer (TIR-based schedules) been run?";
237 auto writer = buf.CopyOnWrite();
238 writer->shape = {buf->shape[0] * stride};
239
240 buf_remap_[key] = buf;
241 return buf;
242 }
243
244 PrimExpr VisitExpr_(const VarNode* op) final {
245 ICHECK(!dbuffer_info_.count(op));
246 return GetRef<PrimExpr>(op);
247 }
248
249 private:
250 Stmt MakeProducer(const AttrStmtNode* op) {
251 const Var buffer = Downcast<Var>(op->node);
252 ICHECK_NE(loop_nest_.size(), 0U) << "Double buffer scope must be inside a loop";
253 auto it = dbuffer_info_.find(buffer.get());
254 if (it == dbuffer_info_.end()) {
255 LOG(WARNING) << "Skip double buffer scope " << op->node;
256 return this->VisitStmt(op->body);
257 }
258 StorageEntry& e = it->second;
259 e.loop = loop_nest_.back();
260 PrimExpr zero = make_const(e.loop->loop_var.dtype(), 0);
261 PrimExpr one = make_const(e.loop->loop_var.dtype(), 1);
262 PrimExpr two = make_const(e.loop->loop_var.dtype(), 2);
263 PrimExpr loop_shift = e.loop->loop_var + one;
264 e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db", e.loop->loop_var.dtype());
265 e.switch_read_var = indexmod(e.loop->loop_var, two);
266 in_double_buffer_scope_ = true;
267 Stmt body = this->VisitStmt(op->body);
268 in_double_buffer_scope_ = false;
269 std::unordered_map<const VarNode*, PrimExpr> vmap;
270 vmap[e.switch_write_var.get()] = zero;
271 vmap[e.loop->loop_var.get()] = zero;
272 loop_pre_[e.loop].emplace_back(Substitute(body, vmap));
273 vmap[e.loop->loop_var.get()] = loop_shift;
274 vmap[e.switch_write_var.get()] = indexmod(loop_shift, two);
275 body = Substitute(body, vmap);
276 body = AttrStmt(buffer, attr::double_buffer_write, 1, body);
277 body = IfThenElse(loop_shift < e.loop->extent, body);
278 return body;
279 }
280 // Storage entry for those who need double buffering.
281 struct StorageEntry {
282 // The size of the buffer
283 PrimExpr stride;
284 // The loop we need
285 const ForNode* loop{nullptr};
286 // The switch variable.
287 Var switch_write_var;
288 // The switch variable for reading.
289 PrimExpr switch_read_var;
290 // The storage scope.
291 std::string scope;
292 };
293 // Whether split loop
294 int32_t split_loop_;
295 // Whether we are inside double buffer scope.
296 bool in_double_buffer_scope_{false};
297 // The current loop next
298 std::vector<const ForNode*> loop_nest_;
299 // The allocs to be appended before the loop
300 std::unordered_map<const ForNode*, std::vector<Stmt>> loop_allocs_;
301 // The stmt to be appended before the loop
302 std::unordered_map<const ForNode*, std::vector<Stmt>> loop_pre_;
303 // The allocation size of the buffer
304 std::unordered_map<const VarNode*, StorageEntry> dbuffer_info_;
305 // The updated Buffer objects
306 std::unordered_map<const BufferNode*, Buffer> buf_remap_;
307};
308
309namespace transform {
310
311Pass InjectDoubleBuffer() {
312 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
313 auto* n = f.CopyOnWrite();
314 auto cfg = ctx->GetConfig<InjectDoubleBufferConfig>("tir.InjectDoubleBuffer");
315 if (!cfg.defined()) {
316 cfg = AttrsWithDefaultValues<InjectDoubleBufferConfig>();
317 }
318 n->body = DoubleBufferInjector(cfg.value()->split_loop).Inject(std::move(n->body));
319 return f;
320 };
321 return CreatePrimFuncPass(pass_func, 0, "tir.InjectDoubleBuffer", {});
322}
323
324TVM_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer").set_body_typed(InjectDoubleBuffer);
325
326} // namespace transform
327
328} // namespace tir
329} // namespace tvm
330