/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file storage_flatten.cc
 * \brief Flattens storage from multi-dimensional array to 1D buffer access
 */
// The pass definition originates from the Halide pipeline.

#include <tvm/arith/analyzer.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target_info.h>
#include <tvm/te/operation.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_map>
#include <unordered_set>

#include "../../arith/ir_visitor_with_analyzer.h"
#include "../../runtime/thread_storage_scope.h"
#include "arg_binder.h"
#include "ir_utils.h"

namespace tvm {
namespace tir {

using arith::IRVisitorWithAnalyzer;
using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;

/* Make buffer realize extents and buffer shapes consistent
 *
 * For external buffers, verify that the extents of BufferRealize
 * nodes match the shape of the external buffer. For internal
 * buffers, rewrite the shape of the Buffer objects to match the
 * extent of the BufferRealize, and rewrite indices of
 * BufferLoad/BufferStore nodes to match.
 */
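/* For example (an illustrative sketch of the rewrite, not generated
 * output): an internal buffer A of shape [100] realized as
 *
 *   buffer_realize A([10, 50)) { A[i] = ... }
 *
 * becomes a buffer A' of shape [40] realized as
 *
 *   buffer_realize A'([0, 40)) { A'[i - 10] = ... }
 *
 * with every BufferLoad/BufferStore index shifted down by the
 * minimum of the realized bounds.
 */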
class BufferShapeLegalize : public StmtExprMutator {
 public:
  static transform::Pass Pass() {
    auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) {
      IRVisitorWithAnalyzer bound_analyzer;

      bound_analyzer(func->body);

      auto pass = BufferShapeLegalize(func->buffer_map, &bound_analyzer);

      auto fptr = func.CopyOnWrite();
      fptr->body = pass(std::move(fptr->body));
      if (auto map = func->attrs.GetAttr<Map<Buffer, Array<IndexMap>>>("layout_transform_map")) {
        func = WithAttr(std::move(func), "layout_transform_map", pass.UpdateIndexMap(map.value()));
      }
      return func;
    };
    return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferShapeLegalize", {});
  }

  explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
                               IRVisitorWithAnalyzer* bound_analyzer)
      : bound_analyzer_(bound_analyzer) {
    for (auto kv : extern_buffer_map) {
      Buffer buf = kv.second;
      extern_buffers_.insert(buf);

      BufferEntry remap;
      remap.remap_to = buf;
      remap.index_offsets = Array<PrimExpr>(buf->shape.size(), 0);
      remap.in_scope = true;
      buf_map_[buf] = remap;
    }
  }

  Map<Buffer, Array<IndexMap>> UpdateIndexMap(const Map<Buffer, Array<IndexMap>>& orig) {
    Map<Buffer, Array<IndexMap>> output;
    for (const auto& kv : orig) {
      auto it = buf_map_.find(kv.first);
      if (it != buf_map_.end()) {
        output.Set(it->second.remap_to, kv.second);
      } else {
        output.Set(kv.first, kv.second);
      }
    }
    return output;
  }

  PrimExpr VisitExpr_(const VarNode* op) final {
    auto it = var_remap_.find(op);
    if (it != var_remap_.end()) {
      return it->second;
    } else {
      return GetRef<PrimExpr>(op);
    }
  }

  Stmt VisitStmt_(const BufferRealizeNode* op) final {
    // BufferRealizeNode for an external buffer serves as an
    // annotation of the external buffers, and should not be changed.
    // Instead, verify that the bounds match the external
    // buffer.
    if (extern_buffers_.count(op->buffer)) {
      CHECK_EQ(op->buffer->shape.size(), op->bounds.size())
          << "External buffer realize has mismatched dimension";
      Stmt stmt = StmtExprMutator::VisitStmt_(op);
      op = stmt.as<BufferRealizeNode>();
      ICHECK(op);

      for (size_t i = 0; i < op->bounds.size(); i++) {
        PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent);
        std::ostringstream ss;
        ss << "Dim " << i << " of external buffer " << op->buffer->name << " has shape "
           << op->buffer->shape[i] << ", but is only realized for extent " << op->bounds[i]->extent;
        if (auto eq_int = eq.as<IntImmNode>()) {
          ICHECK(eq_int->value) << ss.str();
        } else {
          stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt);
        }
      }
      return stmt;
    }

    // Compute the new buffer shape, new realization bounds, and the
    // offsets to be applied to buffer access.
    Array<PrimExpr> realized_shape;
    Array<PrimExpr> index_offsets;
    Array<Range> new_bounds;
    for (size_t i = 0; i < op->bounds.size(); i++) {
      const Range& bound = op->bounds[i];
      realized_shape.push_back(bound->extent);
      index_offsets.push_back(bound->min);
      new_bounds.push_back({0, bound->extent});
    }

    if (op->buffer->shape.size()) {
      ICHECK_EQ(op->buffer->shape.size(), realized_shape.size())
          << "Inconsistency between dimension of buffer " << op->buffer
          << " and dimension of its realized bounds.";
    }

    Buffer key = op->buffer;

    Buffer buf = op->buffer;
    auto write_ptr = buf.CopyOnWrite();
    write_ptr->shape = realized_shape;

    {
      BufferEntry remap;
      remap.remap_to = buf;
      remap.index_offsets = index_offsets;
      remap.in_scope = true;
      buf_map_[key] = remap;
    }

    Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span);

    buf_map_.at(key).in_scope = false;

    return stmt;
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    return VisitBufferAccess(std::move(node));
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    return VisitBufferAccess(std::move(node));
  }

  template <typename Node>
  Node VisitBufferAccess(Node node) {
    auto it = buf_map_.find(node->buffer);
    if (it != buf_map_.end()) {
      const BufferEntry& entry = it->second;
      ICHECK(entry.in_scope) << "Cannot access an out-of-scope buffer";

      Array<PrimExpr> indices = node->indices;
      if (entry.index_offsets.size()) {
        ICHECK_GE(entry.index_offsets.size(), indices.size())
            << "Cannot bind buffer to a shape of lower dimension.";

        Array<PrimExpr> new_indices;

        // Pad leading indices with zero, matching the "fuzzy_match"
        // behavior from ArgBinder::BindBuffer.
        size_t diff = entry.index_offsets.size() - indices.size();
        for (size_t i = 0; i < diff; i++) {
          new_indices.push_back(0);
        }

        // Offset indices used to access buffers of a reduced size.
        for (size_t i = 0; i < indices.size(); i++) {
          PrimExpr offset = entry.index_offsets[i + diff];
          new_indices.push_back(indices[i] - offset);
        }
        indices = new_indices;
      }

      auto write_ptr = node.CopyOnWrite();
      write_ptr->indices = indices;
      write_ptr->buffer = entry.remap_to;
    }
    return node;
  }

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->node->IsInstance<tir::BufferNode>()) {
      // Visit body before checking buf_map_, because we don't know
      // if the BufferNode needs to be changed until we look in the
      // body for a BufferRealizeNode with different extents.
      Stmt body = this->VisitStmt(op->body);

      Buffer buffer = Downcast<tir::Buffer>(op->node);
      auto it = buf_map_.find(buffer);
      if (it != buf_map_.end()) {
        return AttrStmt(it->second.remap_to, op->attr_key, op->value, body);
      }
      return AttrStmt(buffer, op->attr_key, op->value, body);

    } else if (op->attr_key == attr::buffer_bind_scope) {
      return HandleBufferBindScope(op);
    }

    return StmtExprMutator::VisitStmt_(op);
  }

 private:
  // Any buffers that give views into a resized buffer should be
  // updated, both to refer to the resized buffer and to have the view
  // window updated. For example, suppose B1 is a 1-D buffer of size
  // 100 which is only realized on the range (10,50), and buffer V1 is
  // a view into B1[25:35]. When B1 is replaced with B2, a buffer of
  // size 40 realized on the range (0,40), V1 must be replaced to be a
  // view into B2[15:25].
  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
    ICHECK_EQ(arr.size(), 2U);
    Buffer buffer = Downcast<Buffer>(arr[0]);
    ICHECK(buffer.defined());
    Buffer target = Downcast<Buffer>(arr[1]);
    ICHECK(target.defined());

    auto it = buf_map_.find(target);
    ICHECK(it != buf_map_.end()) << "attr::buffer_bind_scope target " << target << " not in scope.";
    const BufferEntry& target_remap = it->second;

    ICHECK(target_remap.in_scope) << "Cannot bind " << buffer->name
                                  << " to the out-of-scope buffer " << target_remap.remap_to->name;

    Call tuple = Downcast<Call>(op->value);
    ICHECK(tuple.defined() && tuple->op.same_as(builtin::tvm_tuple()));

    Array<PrimExpr> new_tuple_args;
    Array<PrimExpr> realized_begins;
    Array<PrimExpr> view_shape;
    ICHECK_EQ(tuple->args.size(), target_remap.index_offsets.size() * 2)
        << "attr::buffer_bind_scope to define " << buffer << " as a view into " << target
        << " does not match dimensionality of " << target;
    for (size_t i = 0; i < target_remap.index_offsets.size(); i++) {
      PrimExpr parent_begin = tuple->args[2 * i];
      PrimExpr view_extent = tuple->args[2 * i + 1];
      // Offset the begin of the buffer view by the offset of the target buffer.
      new_tuple_args.push_back(parent_begin - target_remap.index_offsets[i]);
      // Keep the extent of the buffer view the same.
      new_tuple_args.push_back(view_extent);
      // Use the extent of the buffer view to define the buffer view's shape.
      view_shape.push_back(view_extent);
      // Within the buffer view, indices start at 0.
      realized_begins.push_back(0);
    }

    // If a view is binding to a buffer of a higher dimensionality,
    // then the leading dimensions should be padded out with shape of
    // 1.
    ICHECK_GE(view_shape.size(), buffer->shape.size())
        << "Cannot bind " << buffer << " to a shape of lower dimension.";
    if (view_shape.size() > buffer->shape.size()) {
      size_t diff = view_shape.size() - buffer->shape.size();
      Array<PrimExpr> padded_shape;
      for (size_t i = 0; i < diff; i++) {
        padded_shape.push_back(1);
      }
      for (auto dim : buffer->shape) {
        padded_shape.push_back(dim);
      }
      view_shape = std::move(padded_shape);
    }

    // If a buffer has strides defined, and is being remapped into a
    // shape with additional dimensions, then define dummy values for
    // the strides.
    Array<PrimExpr> realized_strides = buffer->strides;
    if ((realized_strides.size() > 0) && (realized_strides.size() != view_shape.size())) {
      ICHECK_GE(view_shape.size(), realized_strides.size())
          << "Cannot bind the strides of " << buffer << " to a shape of lower dimension";
      size_t diff = view_shape.size() - buffer->strides.size();

      Array<PrimExpr> updated_strides;
      for (size_t i = 0; i < diff; i++) {
        updated_strides.push_back(Var("stride", buffer->shape[0].dtype()));
      }
      for (auto stride : buffer->strides) {
        updated_strides.push_back(stride);
      }
      realized_strides = updated_strides;
    }

    Buffer key = buffer;

    auto write_ptr = buffer.CopyOnWrite();
    write_ptr->shape = view_shape;
    write_ptr->strides = realized_strides;

    {
      BufferEntry remap;
      remap.index_offsets = realized_begins;
      remap.remap_to = buffer;
      remap.in_scope = true;
      buf_map_[key] = remap;
    }

    // Define remappings of any Variables referencing Buffer internals
    // (e.g. Store/Load nodes). Passing fuzzy_match=true allows the
    // remapped buffer to have a different number of dimensions.
    ArgBinder binder(&var_remap_);
    binder.BindBuffer(key, buffer, key->name, true);

    Stmt body = this->VisitStmt(op->body);
    body = MergeNest(binder.asserts(), body);
    body = MergeNest(binder.init_nest(), body);

    Stmt stmt = AttrStmt(Array<ObjectRef>{buffer, target_remap.remap_to}, op->attr_key,
                         Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span), body);

    for (const Var& v : binder.defs()) {
      var_remap_.erase(v.get());
    }

    buf_map_.at(key).in_scope = false;
    return stmt;
  }

  std::unordered_map<const VarNode*, PrimExpr> var_remap_;

  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> extern_buffers_;

  struct BufferEntry {
    Buffer remap_to;
    Array<PrimExpr> index_offsets;
    bool in_scope;
  };

  std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;

  IRVisitorWithAnalyzer* bound_analyzer_;
};

/* Apply dimension alignment restrictions
 *
 * Buffers annotated with attr::buffer_dim_align may need to have
 * strides defined such that they are no longer in a compact shape.
 * After this pass, buffers have stride definitions to include these
 * alignment restrictions, and attr::buffer_dim_align annotations have
 * been removed.
 */
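/* For example (illustrative, assuming a single buffer_dim_align
 * annotation): a compact buffer of shape [16, 30] whose dim-0 stride
 * is aligned to a factor of 32 with offset 0 receives strides
 * [32, 1] rather than the compact [30, 1]; each row is padded by two
 * elements so that row starts remain 32-aligned.
 */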
class BufferStrideLegalize : public StmtExprMutator {
 public:
  static transform::Pass Pass() {
    auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) {
      IRVisitorWithAnalyzer bound_analyzer;

      bound_analyzer(func->body);

      auto pass = BufferStrideLegalize(func->buffer_map, &bound_analyzer);

      auto fptr = func.CopyOnWrite();
      fptr->body = pass(std::move(fptr->body));
      fptr->buffer_map = pass.UpdatedExternBufferMap();
      if (auto map = func->attrs.GetAttr<Map<Buffer, Array<IndexMap>>>("layout_transform_map")) {
        func = WithAttr(std::move(func), "layout_transform_map", pass.UpdateIndexMap(map.value()));
      }
      return func;
    };
    return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferStrideLegalize", {});
  }

  explicit BufferStrideLegalize(const Map<Var, Buffer>& extern_buffer_map,
                                IRVisitorWithAnalyzer* bound_analyzer)
      : bound_analyzer_(bound_analyzer) {
    for (auto kv : extern_buffer_map) {
      Buffer buf = kv.second;
      Buffer with_strides = WithStrides(buf);
      {
        BufferEntry entry;
        entry.remap_to = with_strides;
        entry.in_scope = true;
        buf_map_[buf] = entry;
      }
      updated_extern_buffer_map_.Set(kv.first, with_strides);
    }
  }

  Map<Buffer, Array<IndexMap>> UpdateIndexMap(const Map<Buffer, Array<IndexMap>>& orig) {
    Map<Buffer, Array<IndexMap>> output;
    for (const auto& kv : orig) {
      auto it = buf_map_.find(kv.first);
      if (it != buf_map_.end()) {
        output.Set(it->second.remap_to, kv.second);
      } else {
        output.Set(kv.first, kv.second);
      }
    }
    return output;
  }

  Map<Var, Buffer> UpdatedExternBufferMap() const { return updated_extern_buffer_map_; }

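  // Returns a buffer with strides defined as needed, caching the
  // result so that repeated lookups of the same buffer return the
  // same remapped buffer.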
  Buffer WithStrides(Buffer buf) {
    auto cache_key = buf;

    auto it = buf_map_.find(cache_key);
    if (it != buf_map_.end()) {
      const BufferEntry& entry = it->second;
      ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer";
      return entry.remap_to;
    }

    Array<PrimExpr> shape = buf->shape;

    if (buf->strides.size()) {
      ICHECK_EQ(buf->strides.size(), buf->shape.size())
          << "Buffer " << buf << " has inconsistent strides/shape.";
    } else if (dim_align_.count(buf) == 0) {
      // Keeping this to match the behavior of the previous version.
      // There are many parts of the codebase that assume that a
      // strided array cannot be compact. For example,
      // ArgBinder::BindBuffer and tir.Specialize. To avoid breaking
      // these, do not define the strides unless required for a
      // non-compact array.
    } else if (shape.size() == 0) {
      // Can't define the strides for a buffer without a known shape.
    } else {
      // With everything checked, can now define the updated strides
      std::vector<PrimExpr> rstrides;
      const std::vector<DimAlignInfo>& avec = dim_align_[buf];
      int first_dim = 0;
      PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
      for (size_t i = shape.size(); i != 0; --i) {
        size_t dim = i - 1;
        if (dim < avec.size() && avec[dim].align_factor != 0) {
          PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
          PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
          stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
          stride = bound_analyzer_->Simplify(stride);
        }
        rstrides.push_back(stride);
        stride = stride * shape[dim];
      }

      buf.CopyOnWrite()->strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
    }

    BufferEntry entry;
    entry.remap_to = buf;
    entry.in_scope = true;
    buf_map_[cache_key] = entry;

    return buf;
  }

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::buffer_dim_align) {
      auto buffer = Downcast<tir::Buffer>(op->node);
      const CallNode* tuple = op->value.as<CallNode>();
      ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
      auto& vinfo = dim_align_[buffer];
      int dim = tuple->args[0].as<IntImmNode>()->value;
      if (static_cast<size_t>(dim) >= vinfo.size()) {
        vinfo.resize(dim + 1);
      }
      vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value;
      vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value;

      return this->VisitStmt(op->body);
    } else if (op->attr_key == attr::buffer_bind_scope) {
      Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
      ICHECK_EQ(arr.size(), 2U);
      Buffer source = Downcast<Buffer>(arr[0]);
      Buffer target_with_strides = WithStrides(Downcast<Buffer>(arr[1]));
      Buffer source_with_strides = WithStrides(source);

      Stmt body = this->VisitStmt(op->body);

      buf_map_[source].in_scope = false;

      return AttrStmt(Array<ObjectRef>{source_with_strides, target_with_strides}, op->attr_key,
                      op->value, body, op->span);
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

  // AllocateNodes may be present from tvm.tir.ir_builder. This can
  // be simplified in the future by having AllocateNode hold a buffer,
  // rather than a buffer_var.
  Stmt VisitStmt_(const AllocateNode* op) final {
    buffer_var_defines_.insert(op->buffer_var.get());
    return StmtExprMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const AllocateConstNode* op) final {
    buffer_var_defines_.insert(op->buffer_var.get());
    return StmtExprMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const LetStmtNode* op) final {
    if (op->var.dtype().is_handle()) {
      buffer_var_defines_.insert(op->var.get());
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  PrimExpr VisitExpr_(const LetNode* op) final {
    if (op->var.dtype().is_handle()) {
      buffer_var_defines_.insert(op->var.get());
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  Stmt VisitStmt_(const BufferRealizeNode* op) final {
    Buffer key = op->buffer;
    Buffer with_strides = WithStrides(op->buffer);

    Stmt stmt = StmtExprMutator::VisitStmt_(op);

    buf_map_[key].in_scope = false;
    op = stmt.as<BufferRealizeNode>();
    ICHECK(op);

    return BufferRealize(with_strides, op->bounds, op->condition, op->body, op->span);
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    return VisitBufferAccess(std::move(node));
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    return VisitBufferAccess(std::move(node));
  }

  template <typename Node>
  Node VisitBufferAccess(Node node) {
    auto it = buf_map_.find(node->buffer);
    ICHECK(it == buf_map_.end() || it->second.in_scope)
        << "Cannot access buffer " << node->buffer->name << " after it is out of scope";

    auto with_strides = WithStrides(node->buffer);
    if (!with_strides.same_as(node->buffer)) {
      node.CopyOnWrite()->buffer = with_strides;
    }

    return node;
  }

 private:
  Map<Var, Buffer> updated_extern_buffer_map_;

  struct DimAlignInfo {
    int align_factor{0};
    int align_offset{0};
  };

  // Dimension alignment
  std::unordered_map<Buffer, std::vector<DimAlignInfo>, ObjectPtrHash, ObjectPtrEqual> dim_align_;

  struct BufferEntry {
    Buffer remap_to;
    bool in_scope;
  };

  std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;

  // Set of vars that have occurred in an AllocateNode, but haven't
  // yet occurred in a BufferLoad/BufferStore.
  std::unordered_set<const VarNode*> buffer_var_defines_;

  IRVisitorWithAnalyzer* bound_analyzer_;
};

/* Use the scope of IterVar to determine storage scope.
 *
 * For buffers that do not have an explicit storage scope defined, a
 * reasonable storage scope may be defined based on the thread scope
 * that contains the buffer's allocation. All other buffers without a
 * scope are assigned to global scope.
 */
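/* For example (illustrative): an unscoped buffer realized inside a
 * blockIdx thread_extent annotation is assigned "shared" scope, one
 * realized inside a threadIdx scope is assigned "local" scope, and
 * one realized outside any thread scope is assigned "global" scope,
 * following runtime::DefaultStorageRank.
 */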
class ThreadScopePropagate : public StmtExprMutator {
 public:
  static transform::Pass Pass() {
    auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) {
      auto pass = ThreadScopePropagate(func->buffer_map);

      auto fptr = func.CopyOnWrite();
      fptr->body = pass(std::move(fptr->body));
      if (auto map = func->attrs.GetAttr<Map<Buffer, Array<IndexMap>>>("layout_transform_map")) {
        func = WithAttr(std::move(func), "layout_transform_map", pass.UpdateIndexMap(map.value()));
      }
      return func;
    };
    return transform::CreatePrimFuncPass(pass_func, 0, "tir.ThreadScopePropagate", {});
  }

  explicit ThreadScopePropagate(const Map<Var, Buffer>& extern_buffer_map) {
    // External buffers shouldn't be overwritten, even if they have a
    // BufferRealizeNode.
    for (auto kv : extern_buffer_map) {
      external_buffers_.insert(kv.second);
    }
  }

  Map<Buffer, Array<IndexMap>> UpdateIndexMap(const Map<Buffer, Array<IndexMap>>& orig) {
    Map<Buffer, Array<IndexMap>> output;
    for (const auto& kv : orig) {
      auto it = buf_remap_.find(kv.first->data);
      if (it != buf_remap_.end()) {
        output.Set(it->second, kv.second);
      } else {
        output.Set(kv.first, kv.second);
      }
    }
    return output;
  }

  PrimExpr VisitExpr_(const VarNode* op) final {
    auto it = buf_remap_.find(GetRef<Var>(op));
    if (it != buf_remap_.end()) {
      return it->second->data;
    } else {
      return GetRef<PrimExpr>(op);
    }
  }

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    ICHECK_NE(op->attr_key, attr::buffer_dim_align)
        << "ThreadScopePropagate assumes that all buffers have accurate strides, "
        << "and all buffer_dim_align annotations are removed. "
        << "Please run BufferStrideLegalize first.";

    if (op->attr_key == attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ThreadScope ts = ThreadScope::Create(iv->thread_tag);
      curr_thread_scope_.push_back(ts);
      Stmt stmt = StmtExprMutator::VisitStmt_(op);
      curr_thread_scope_.pop_back();
      return stmt;
    } else if (op->attr_key == attr::buffer_bind_scope) {
      return HandleBufferBindScope(op);
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

  Stmt VisitStmt_(const BufferRealizeNode* op) final {
    Var old_var = op->buffer->data;

    // Don't remap buffers that already have an explicit scope,
    // or external buffers.
    std::string str_scope = GetPtrStorageScope(old_var);
    if ((str_scope.length() > 0) || external_buffers_.count(op->buffer)) {
      return StmtExprMutator::VisitStmt_(op);
    }

    ICHECK_EQ(buf_remap_.count(old_var), 0)
        << "Buffer var " << op->buffer->data << " appears in multiple BufferRealize nodes";

    StorageScope skey;
    if (curr_thread_scope_.size() == 0) {
      skey.rank = StorageRank::kGlobal;
    } else {
      skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank);
    }

    auto ptr_type = old_var->type_annotation.as<PointerTypeNode>();
    ICHECK(ptr_type);
    Var new_var(old_var->name_hint, PointerType(ptr_type->element_type, skey.to_string()),
                old_var->span);

    Buffer buf = op->buffer;
    buf.CopyOnWrite()->data = new_var;

    buf_remap_[old_var] = buf;

    Stmt body = this->VisitStmt(op->body);
    return BufferRealize(buf, op->bounds, op->condition, body, op->span);
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<BufferLoadNode>();
    ICHECK(op);

    auto it = buf_remap_.find(op->buffer->data);
    if (it != buf_remap_.end()) {
      return BufferLoad(it->second, op->indices, op->span);
    } else {
      return expr;
    }
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<BufferStoreNode>();
    ICHECK(op);

    auto it = buf_remap_.find(op->buffer->data);
    if (it != buf_remap_.end()) {
      return BufferStore(it->second, op->value, op->indices, op->span);
    } else {
      return stmt;
    }
  }

 private:
  // If the rewritten buffers are part of a buffer_bind_scope, either
  // as the buffer view or as the buffer being viewed, then the
  // buffer_bind_scope must be rewritten to refer to the updated
  // buffers.
  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
    ICHECK_EQ(arr.size(), 2U);
    Buffer buffer = Downcast<Buffer>(arr[0]);
    ICHECK(buffer.defined());
    Buffer target = Downcast<Buffer>(arr[1]);
    ICHECK(target.defined());

    bool needs_rewrite = false;

    {
      auto it = buf_remap_.find(buffer->data);
      if (it != buf_remap_.end()) {
        needs_rewrite = true;
        buffer = it->second;
      }
    }

    {
      auto it = buf_remap_.find(target->data);
      if (it != buf_remap_.end()) {
        needs_rewrite = true;
        target = it->second;
      }
    }

    if (needs_rewrite) {
      Stmt body = this->VisitStmt(op->body);
      return AttrStmt(Array<ObjectRef>{buffer, target}, op->attr_key, op->value, body);
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

  std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buf_remap_;
  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> external_buffers_;

  // The current thread scope.
  std::vector<ThreadScope> curr_thread_scope_;
};

/* Map buffer binds to their source buffer
 *
 * Buffers defined using an attr::buffer_bind_scope annotation are
 * views into some linked buffer, potentially into some restricted
 * subregion of that buffer. This pass identifies such buffers, then
 * rewrites all accesses of the bound buffers to be accesses into the
 * linked buffer.
 */
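/* For example (illustrative): if V is bound as a view into
 * B[10:20, 4:8], then a load V[i, j] is rewritten to
 * B[10 + i, 4 + j], and V itself no longer appears in the output.
 */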
class BufferBindUnwrapper : public StmtExprMutator {
 public:
  static transform::Pass Pass() {
    auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) {
      IRVisitorWithAnalyzer bound_analyzer;

      bound_analyzer(func->body);

      auto pass = BufferBindUnwrapper(func->buffer_map, &bound_analyzer);

      auto fptr = func.CopyOnWrite();
      fptr->body = pass(std::move(fptr->body));
      return func;
    };
    return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferBindUnwrapper", {});
  }

  explicit BufferBindUnwrapper(const Map<Var, Buffer>& extern_buffer_map,
                               IRVisitorWithAnalyzer* bound_analyzer)
      : bound_analyzer_(bound_analyzer) {
    for (auto kv : extern_buffer_map) {
      BufferEntry e;
      e.buffer = kv.second;
      e.external = true;
      var_to_buffer_[kv.second->data.get()] = kv.second;
      buf_map_[kv.second.get()] = std::move(e);
    }
  }

  Map<Buffer, Array<IndexMap>> UpdateIndexMap(const Map<Buffer, Array<IndexMap>>& orig) {
    Map<Buffer, Array<IndexMap>> output;
    for (const auto& kv : orig) {
      const BufferEntry& e = GetBufferEntry(kv.first);

      if (e.remap) {
        output.Set(e.remap->target, kv.second);
      } else {
        output.Set(kv.first, kv.second);
      }
    }
    return output;
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    ICHECK_NE(op->attr_key, attr::buffer_dim_align)
        << "BufferBindUnwrapper assumes that all buffers have accurate strides, "
        << "and all buffer_dim_align annotations are removed. "
        << "Please run BufferStrideLegalize first.";

    if (op->attr_key == attr::buffer_bind_scope) {
      return HandleBufferBindScope(op);
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

  PrimExpr VisitExpr_(const VarNode* op) final {
    ICHECK(!illegal_vars_.count(op)) << "Variable " << op->name_hint << " is not well defined. "
                                     << "(e.g. use of buffer.elem_offset for a non-flat buffer)";

    auto it = var_remap_.find(op);
    if (it != var_remap_.end()) {
      return it->second;
    } else {
      return GetRef<PrimExpr>(op);
    }
  }

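  // Remap view indices into indices on the backing buffer by adding
  // the begin offset of each dimension, e.g. view index (i, j) over
  // begins (b0, b1) becomes (b0 + i, b1 + j).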
  Array<PrimExpr> remap_indices(Array<PrimExpr> indices, Array<PrimExpr> begins,
                                Array<PrimExpr> extents) {
    ICHECK_EQ(begins.size(), extents.size());

    if (begins.size() == 0) {
      return indices;
    }

    ICHECK_EQ(begins.size(), indices.size());

    Array<PrimExpr> out;
    for (size_t i = 0; i < begins.size(); i++) {
      out.push_back(begins[i] + indices[i]);
    }
    return out;
  }

  Array<Range> remap_bounds(Array<Range> bounds, Array<PrimExpr> begins, Array<PrimExpr> extents) {
    ICHECK_EQ(begins.size(), extents.size());

    if (begins.size() == 0) {
      return bounds;
    }

    ICHECK_EQ(begins.size(), bounds.size());

    Array<Range> out;
    for (size_t i = 0; i < begins.size(); i++) {
      out.push_back(Range::FromMinExtent(bounds[i]->min + begins[i], bounds[i]->extent));
    }
    return out;
  }

  // AllocateNodes may be present from tvm.tir.ir_builder. This can
  // be simplified in the future by having AllocateNode hold a buffer,
  // rather than a buffer_var.
  Stmt VisitStmt_(const AllocateNode* op) final {
    buffer_var_defines_.insert(op->buffer_var.get());
    return StmtExprMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const AllocateConstNode* op) final {
    buffer_var_defines_.insert(op->buffer_var.get());
    return StmtExprMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const LetStmtNode* op) final {
    if (op->var.dtype().is_handle()) {
      buffer_var_defines_.insert(op->var.get());
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  PrimExpr VisitExpr_(const LetNode* op) final {
    if (op->var.dtype().is_handle()) {
      buffer_var_defines_.insert(op->var.get());
    }
    return StmtExprMutator::VisitExpr_(op);
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<BufferLoadNode>();

    const BufferEntry& e = GetBufferEntry(op->buffer);

    if (e.remap) {
      return BufferLoad(e.remap->target,
                        remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
    } else {
      return expr;
    }
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<BufferStoreNode>();

    const BufferEntry& e = GetBufferEntry(op->buffer);

    if (e.remap) {
      return BufferStore(e.remap->target, op->value,
                         remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
    } else {
      return stmt;
    }
  }

  Stmt VisitStmt_(const BufferRealizeNode* op) final {
    const auto& key = op->buffer.get();

    bool is_external = false;

    if (buf_map_.count(key)) {
      ICHECK(buf_map_.at(key).external)
          << "BufferRealize node for internal buffer " << op->buffer << " occurred multiple times.";

      is_external = true;
    } else {
      BufferEntry e;
      e.bounds = op->bounds;
      e.buffer = op->buffer;
      var_to_buffer_[op->buffer->data.get()] = op->buffer;
      buf_map_[key] = std::move(e);
    }

    Stmt stmt = StmtExprMutator::VisitStmt_(op);

    if (is_external) {
      buf_map_[key].in_scope = false;
    }

    return stmt;
  }

  Stmt VisitStmt_(const PrefetchNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<PrefetchNode>();
    ICHECK(op != nullptr);

    const BufferEntry& e = GetBufferEntry(op->buffer);

    ICHECK(e.in_scope) << "Read a buffer that is already out of scope";
    ICHECK_EQ(e.buffer->shape.size(), op->bounds.size())
        << "Prefetch dim should be the same as buffer dim";

    if (e.remap) {
      return Prefetch(e.remap->target, remap_bounds(op->bounds, e.remap->begins, e.remap->extents),
                      op->span);
    } else {
      return stmt;
    }
  }

 private:
  // Read the mapping from a buffer view to the actual buffer. This
  // allows all later BufferStore/BufferLoad nodes to reference the
  // actual buffer, rather than the buffer view.
  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
    // Unpack information from Attribute node
    RemapInfo remap;

    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
    ICHECK_EQ(arr.size(), 2U);
    const Buffer source = Downcast<Buffer>(arr[0]);
    ICHECK(source.defined());
    remap.target = Downcast<Buffer>(arr[1]);
    ICHECK(remap.target.defined());
    const CallNode* tuple = op->value.as<CallNode>();
    ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));

    for (size_t i = 0; i < tuple->args.size(); i += 2) {
      remap.begins.push_back(tuple->args[i]);
      remap.extents.push_back(tuple->args[i + 1]);
    }

    // Determine bounds in the target buffer
    auto it = buf_map_.find(remap.target.get());
    ICHECK(it != buf_map_.end()) << "Cannot define " << source << " as a view into " << remap.target
                                 << ", " << remap.target << " was not defined.";
    const BufferEntry& target_info = it->second;
    ICHECK(target_info.in_scope) << "Cannot define " << source << " as a view into " << remap.target
                                 << ", " << remap.target << " is out of scope.";
    ICHECK_EQ(remap.begins.size(), target_info.buffer->shape.size())
        << "Incorrect number of arguments in buffer_bind_scope attribute. "
        << "Expected (min_0, extent_0, min_1, extent_1, ..., min_N, extent_N).";

    if (target_info.bounds.size() > 0) {
      Array<PrimExpr> mapped_begins;
      for (size_t i = 0; i < target_info.buffer->shape.size(); ++i) {
        mapped_begins.push_back(remap.begins[i] - target_info.bounds[i]->min);
      }
      remap.begins = std::move(mapped_begins);
    }

    ICHECK(target_info.remap == nullptr)
        << "buffer_bind_scope defines " << source << " as a view into " << remap.target
        << ", which is itself a buffer view. "
        << "Indirect remapping not currently supported.";

    for (size_t i = 0; i < remap.begins.size(); i++) {
      remap.begins.Set(i, bound_analyzer_->Simplify(remap.begins[i]));
      remap.extents.Set(i, bound_analyzer_->Simplify(remap.extents[i]));
    }

    // Add a buffer remap entry
    {
      BufferEntry source_info;
      source_info.buffer = source;
      source_info.remap = std::make_unique<RemapInfo>(remap);

      var_to_buffer_[source->data.get()] = source;
      buf_map_[source.get()] = std::move(source_info);
    }

    // Define remappings of any remaining Variables (e.g. Store/Load nodes).
    ArgBinder binder(&var_remap_);

    // Define a view that represents the source's view into the target
    // buffer. This Buffer object is only used to define the mapping
    // to the target buffer, and never actually appears in the TIR
    // graph.
    Buffer view = remap.target.MakeSlice(remap.begins, remap.extents);
    if (source->strides.size() == 0) {
      ICHECK_EQ(view->strides.size(), 0U)
          << "Cannot bind a compact buffer " << source << " to a strided buffer " << view
          << " with strides " << view->strides;
    } else {
      // Add explicit strides to the view, in order to bind to source.strides[i].
      view = view.MakeStrideView();
    }

    // Match integer bits of source->elem_offset and view->elem_offset
    // as is required by ArgBinder::Bind_
    if (view->elem_offset.defined() && source->elem_offset.dtype() != view->elem_offset.dtype()) {
      view.CopyOnWrite()->elem_offset = cast(source->elem_offset.dtype(), view->elem_offset);
    }

    // Bind any variables that reference the view (e.g. elem_offset,
    // strides, shape). Pass fuzzy_match=false, because all shape
    // transformations should have been handled in
    // BufferShapeLegalize.
    binder.BindBuffer(source, view, source->name, false);
    if (auto* elem_offset_var = source->elem_offset.as<VarNode>()) {
      if (!view->elem_offset.defined()) {
        illegal_vars_.insert(elem_offset_var);
      }
    }

    // Apply the remaps
    Stmt body = op->body;
    body = MergeNest(binder.asserts(), body);
    body = MergeNest(binder.init_nest(), body);
    body = this->VisitStmt(body);
    // remove the binds
    for (const Var& v : binder.defs()) {
      var_remap_.erase(v.get());
    }
    return body;
  }

  struct RemapInfo {
    Buffer target;
    Array<PrimExpr> begins;
    Array<PrimExpr> extents;
  };

  // The buffer entry in the flatten map
  struct BufferEntry {
    // The storage buffer
    Buffer buffer;
    // The bounds of the realization; if empty, the entire buffer is realized.
    Region bounds;
    // Whether the buffer is external
    bool external{false};
    // Whether we are within the allocation scope of the buffer.
    bool in_scope{true};

    // The buffer to which the storage buffer should be remapped.
    std::unique_ptr<RemapInfo> remap{nullptr};
  };

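  // Look up (or lazily create) the BufferEntry for a buffer. Aliased
  // buffers that share a data var with a known buffer get their own
  // entry, so that later loads/stores can be remapped consistently.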
  const BufferEntry& GetBufferEntry(Buffer buffer) {
    if (buf_map_.count(buffer.get())) {
      const BufferEntry& e = buf_map_[buffer.get()];
      ICHECK(e.in_scope) << "Cannot access buffer " << buffer->name << " after it is out of scope";
      return e;
    } else if (buffer_var_defines_.count(buffer->data.get())) {
      // The buffer var was defined, but the buffer hasn't been seen
      // before.
      BufferEntry entry;
      entry.buffer = buffer;
      var_to_buffer_[buffer->data.get()] = buffer;
      buf_map_[buffer.get()] = std::move(entry);
      return buf_map_[buffer.get()];
    } else if (var_remap_.count(buffer->data.get())) {
      // The buffer var is an alias of a bound buffer. Only
      // supported if the bound buffer has no offsets. In this
      // case, we just need to make a new aliasing buffer that
      // shares the remapped data variable.
      Var old_var = buffer->data;
      Var new_var = Downcast<Var>(var_remap_[old_var.get()]);

      {
        ICHECK(var_to_buffer_.count(old_var.get()))
            << "Cannot find remap information for aliased buffer var " << old_var->name_hint
            << ", required to verify this alias is legal.";
        const Buffer& aliased_buffer = var_to_buffer_[old_var.get()];
        const BufferEntry& entry = buf_map_[aliased_buffer.get()];
        if (entry.remap) {
          for (const auto& begin : entry.remap->begins) {
            ICHECK(is_zero(begin)) << "Aliasing of buffer with offset is not supported";
          }
        }
      }

      {
        Buffer new_buf = buffer;
        new_buf.CopyOnWrite()->data = new_var;

        RemapInfo remap_info;
        remap_info.target = new_buf;
        remap_info.begins = Array<PrimExpr>(buffer->shape.size(), 0);
        remap_info.extents = buffer->shape;

        BufferEntry entry;
        entry.buffer = buffer;
        entry.remap = std::make_unique<RemapInfo>(remap_info);
        entry.in_scope = true;
        var_to_buffer_[buffer->data.get()] = buffer;
        buf_map_[buffer.get()] = std::move(entry);
      }
      return buf_map_[buffer.get()];
    } else if (var_to_buffer_.count(buffer->data.get())) {
      // This buffer is an alias of a known buffer, with no remaps. A
      // buffer entry should be generated and returned.
      BufferEntry entry;
      entry.buffer = buffer;
      entry.in_scope = true;
      var_to_buffer_[buffer->data.get()] = buffer;
      buf_map_[buffer.get()] = std::move(entry);

      return buf_map_[buffer.get()];
    } else {
      LOG(FATAL) << "Cannot handle access to undefined buffer " << buffer->name;
    }
  }

  // Variable remap
  std::unordered_map<const VarNode*, PrimExpr> var_remap_;
  // Variables that may not occur within the body.
  std::unordered_set<const VarNode*> illegal_vars_;
  // Buffer map
  std::unordered_map<const BufferNode*, BufferEntry> buf_map_;
  // Map from Var to the Buffer they occurred in. In case of aliased
  // buffers, contains the first buffer.
  std::unordered_map<const VarNode*, Buffer> var_to_buffer_;
  // Set of vars that have occurred in an AllocateNode, but haven't
  // yet occurred in a BufferLoad/BufferStore.
  std::unordered_set<const VarNode*> buffer_var_defines_;
  // Analyzer for variable bounds, used to simplify the begins and
  // extents of buffer views.
  IRVisitorWithAnalyzer* bound_analyzer_;
};

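/* Apply any layout transforms attached to the function
 *
 * Buffers with entries in the "layout_transform_map" attribute are
 * rewritten to the transformed layout: buffer shapes and realization
 * bounds are updated with IndexMap::MapShape/MapRanges, and every
 * BufferLoad/BufferStore index is rewritten with
 * IndexMap::MapIndices. The attribute is removed afterwards.
 */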
class ApplyLayoutTransforms : public StmtExprMutator {
 public:
  static transform::Pass Pass() {
    auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) {
      auto lookup = func->attrs.GetAttr<Map<Buffer, Array<IndexMap>>>("layout_transform_map");

      if (!lookup) {
        return func;
      }

      Map<Buffer, Array<IndexMap>> layout_transforms = lookup.value();

      auto fptr = func.CopyOnWrite();

      auto mutator = ApplyLayoutTransforms(layout_transforms);
      fptr->buffer_map = mutator.UpdateExternBufferMap(fptr->buffer_map);
      fptr->body = mutator(std::move(fptr->body));

      return WithoutAttr(std::move(func), "layout_transform_map");
    };
    return transform::CreatePrimFuncPass(pass_func, 0, "tir.ApplyLayoutTransforms", {});
  }

  explicit ApplyLayoutTransforms(Map<Buffer, Array<IndexMap>> layout_transforms)
      : layout_transforms_(layout_transforms) {}

  Map<tir::Var, Buffer> UpdateExternBufferMap(const Map<tir::Var, Buffer>& buffer_map) {
    Map<tir::Var, Buffer> output;
    for (const auto& kv : buffer_map) {
      output.Set(kv.first, GetBufferRemap(kv.second, true));
    }
    return output;
  }

  Stmt VisitStmt_(const BufferRealizeNode* op) final {
    // Call once so that load/store nodes can read from the cached
    // value.
    GetBufferRemap(op->buffer, true);

    auto realize = Downcast<BufferRealize>(StmtExprMutator::VisitStmt_(op));

    auto lookup = layout_transforms_.Get(op->buffer);
    if (lookup) {
      auto write_ptr = realize.CopyOnWrite();
      write_ptr->buffer = GetBufferRemap(op->buffer, true);

      Array<IndexMap> transforms = lookup.value();
      for (const auto& transform : transforms) {
        write_ptr->bounds = transform->MapRanges(realize->bounds);
      }
    }

    return std::move(realize);
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    return VisitBufferAccess(std::move(node));
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    return VisitBufferAccess(std::move(node));
  }

  template <typename Node>
  Node VisitBufferAccess(Node node) {
    auto lookup = layout_transforms_.Get(node->buffer);
    if (lookup) {
      auto write_ptr = node.CopyOnWrite();

      write_ptr->buffer = GetBufferRemap(node->buffer);

      Array<IndexMap> transforms = lookup.value();
      for (const auto& transform : transforms) {
        write_ptr->indices = transform->MapIndices(node->indices);
      }
    }
    return node;
  }

 private:
  //! \brief Given a buffer, return the buffer it should be remapped into.
  Buffer GetBufferRemap(Buffer buf, bool allow_alloc = false) {
    auto key = buf.get();
    auto it = buf_map_.find(key);
    if (it != buf_map_.end()) {
      return it->second;
    }

    ICHECK(allow_alloc) << "Buffer " << buf << " accessed before declaration.";

    auto lookup = layout_transforms_.Get(buf);
    if (lookup) {
      Array<IndexMap> transforms = lookup.value();

      auto write_ptr = buf.CopyOnWrite();
      for (const auto& transform : transforms) {
        write_ptr->shape = transform->MapShape(buf->shape);
      }
    }

    buf_map_[key] = buf;
    return buf;
  }

  std::unordered_map<const BufferNode*, Buffer> buf_map_;

  Map<Buffer, Array<IndexMap>> layout_transforms_;
};

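/* Flatten multi-dimensional buffer access into 1-d buffer access
 *
 * Rewrites each BufferLoad/BufferStore to act on a flattened buffer
 * (via Buffer::GetFlattenedBuffer and Buffer::ElemOffset), and
 * rewrites Allocate/AllocateConst extents to match. This is the
 * final step of the flattening sequence, and expects the preceding
 * passes in this file to have already run.
 */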
class StorageFlattener : public StmtExprMutator {
 public:
  static transform::Pass Pass(int cache_line_size, bool create_bound_attributes) {
    auto pass_func = [=](PrimFunc func, IRModule m, transform::PassContext ctx) {
      IRVisitorWithAnalyzer bound_analyzer;

      bound_analyzer(func->body);

      auto pass = StorageFlattener(func->buffer_map, cache_line_size, create_bound_attributes,
                                   &bound_analyzer);

      auto fptr = func.CopyOnWrite();
      fptr->body = pass(std::move(fptr->body));
      // The buffers in func->buffer_map are deliberately left
      // unflattened, as they are used for validation of user-provided
      // arguments. The flattened buffers used in the updated
      // function body alias the argument buffers.
      return func;
    };
    return transform::CreatePrimFuncPass(pass_func, 0, "tir.StorageFlattener", {});
  }

  explicit StorageFlattener(const Map<Var, Buffer>& extern_buffer_map, int cache_line_size,
                            bool create_bound_attributes, IRVisitorWithAnalyzer* bound_analyzer)
      : bound_analyzer_(bound_analyzer), create_bound_attributes_(create_bound_attributes) {
    for (auto kv : extern_buffer_map) {
      BufferEntry e;
      e.buffer = kv.second;
      e.flattened_buffer = e.buffer.GetFlattenedBuffer();
      // TODO(Lunderberg): Move the handling of boolean into a
      // dedicated pass.

      // Boolean tensors are backed by an Int8 array.
      if (e.buffer->dtype == DataType::Bool()) {
        {
          auto writer = e.buffer.CopyOnWrite();
          writer->dtype = DataType::Int(8);
        }
        {
          auto writer = e.flattened_buffer.CopyOnWrite();
          writer->dtype = DataType::Int(8);
        }
      }
      e.external = true;
      buffer_var_defines_.insert(kv.second->data.get());
      buf_map_[kv.second] = e;
    }
    cache_line_size_ = cache_line_size;
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    ICHECK_NE(op->attr_key, attr::buffer_dim_align)
        << "StorageFlattener assumes that all buffers have accurate strides, "
        << "and all buffer_dim_align annotations are removed. "
        << "Please run BufferStrideLegalize first.";

    ICHECK_NE(op->attr_key, attr::buffer_bind_scope)
        << "StorageFlattener assumes that all buffer binds have already been applied. "
        << "Please run BufferBindUnwrapper first.";

    if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance<tir::BufferNode>()) {
      auto buffer = Downcast<tir::Buffer>(op->node);
      Stmt body = this->VisitStmt(op->body);
      const auto& entry = GetBufferEntry(buffer);
      body = AttrStmt(entry.flattened_buffer->data, op->attr_key, op->value, std::move(body));
      return body;
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    if (create_bound_attributes_) shape_collector_.clear();
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<BufferStoreNode>();

    const BufferEntry& e = GetBufferEntry(op->buffer);

    // Handle casts from the value's dtype to the dtype of the backing
    // array.
    PrimExpr value = op->value;
    if (value.dtype() == DataType::Bool()) {
      ICHECK_EQ(e.flattened_buffer->dtype, DataType::Int(8))
          << "Expected int8 backing array for boolean tensor, but received "
          << e.flattened_buffer->dtype;
      value = tir::Cast(DataType::Int(8), value);
    }

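    // e.g. (illustrative) indices (i, j) on a compact buffer of shape
    // (4, 16) flatten to the single index i*16 + j.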
    auto flattened_indices = e.buffer->ElemOffset(op->indices);

    Stmt body = BufferStore(e.flattened_buffer, value, flattened_indices, op->span);
    if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
      shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape));
    }
    // To create the bound attributes, the collector must have at least one item.
    if (create_bound_attributes_ && shape_collector_.size()) {
      for (size_t i = 0; i < shape_collector_.size(); ++i) {
        body = AttrStmt(shape_collector_[i].first, tir::attr::buffer_bound,
                        MakeBound(e.buffer->dtype, shape_collector_[i].second), body);
      }
    }
    return body;
  }
1449 | |
1450 | // AllocateNodes may be present from tvm.tir.ir_builder. This can |
1451 | // be simplified in the future by having AllocateNode hold a buffer, |
1452 | // rather than a buffer_var. |
1453 | Stmt VisitStmt_(const AllocateNode* op) final { |
1454 | buffer_var_defines_.insert(op->buffer_var.get()); |
1455 | auto stmt = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op)); |
1456 | return Allocate(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), stmt->condition, |
1457 | stmt->body, stmt->annotations, stmt->span); |
1458 | } |
1459 | |
1460 | Stmt VisitStmt_(const AllocateConstNode* op) final { |
1461 | buffer_var_defines_.insert(op->buffer_var.get()); |
1462 | auto stmt = Downcast<AllocateConst>(StmtExprMutator::VisitStmt_(op)); |
1463 | ObjectRef data_or_idx; |
1464 | if (stmt->data) { |
1465 | data_or_idx = stmt->data.value(); |
1466 | } else if (stmt->irmod_storage_idx) { |
1467 | data_or_idx = stmt->irmod_storage_idx.value(); |
1468 | } else { |
1469 | LOG(FATAL) << "Neither data array nor data index specified for allocation of const " |
1470 | << op->buffer_var->name_hint; |
1471 | } |
1472 | return AllocateConst(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), data_or_idx, |
1473 | stmt->body, stmt->annotations, stmt->span); |
1474 | } |
1475 | |
1476 | Stmt VisitStmt_(const LetStmtNode* op) final { |
1477 | if (op->var.dtype().is_handle()) { |
1478 | buffer_var_defines_.insert(op->var.get()); |
1479 | } |
1480 | return StmtExprMutator::VisitStmt_(op); |
1481 | } |
1482 | |
1483 | PrimExpr VisitExpr_(const LetNode* op) final { |
1484 | if (op->var.dtype().is_handle()) { |
1485 | buffer_var_defines_.insert(op->var.get()); |
1486 | } |
1487 | return StmtExprMutator::VisitExpr_(op); |
1488 | } |
1489 | |
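  // For an internal buffer, replace the BufferRealize with an Allocate
  // over the flattened shape. For example (a sketch):
  //
  //   buffer_realize A([0, n), [0, m)) { ... }
  //
  // becomes
  //
  //   allocate A[n*m] { ... }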
1490 | Stmt VisitStmt_(const BufferRealizeNode* op) final { |
1491 | const auto& key = op->buffer; |
1492 | |
1493 | if (buf_map_.count(key)) { |
1494 | ICHECK(buf_map_.at(key).external) |
1495 | << "BufferRealize for internal buffer " << op->buffer << " appears multiple times." ; |
1496 | return this->VisitStmt(op->body); |
1497 | } else { |
1498 | // create a buffer entry |
1499 | BufferEntry e; |
1500 | |
1501 | ICHECK_EQ(op->buffer->shape.size(), op->bounds.size()) |
1502 | << "Inconsistent buffer shape and realization shape for " << op->buffer; |
1503 | |
1504 | for (size_t i = 0; i < op->bounds.size(); i++) { |
1505 | const auto& bound = op->bounds[i]; |
1506 | const auto& dim_size = op->buffer->shape[i]; |
1507 | ICHECK(is_zero(bound_analyzer_->Simplify(bound->min))) |
1508 | << "Buffer " << op->buffer << " has realization bounds that do not start at zero. " |
1509 | << "Please run BufferShapeLegalize first." ; |
1510 | ICHECK(is_one(bound_analyzer_->Simplify(bound->extent == dim_size))) |
1511 | << "Buffer " << op->buffer |
1512 | << " has realization extent that does not match its size. " |
1513 | "Please run BufferShapeLegalize first." ; |
1514 | } |
1515 | |
1516 | StorageScope skey = StorageScope::Create(GetPtrStorageScope(op->buffer->data)); |
1517 | |
1518 | // use small alignment for small arrays |
1519 | auto dtype = op->buffer->dtype; |
1520 | size_t const_size = AllocateNode::ConstantAllocationSize(op->buffer->shape); |
1521 | int align = GetTempAllocaAlignment(dtype, const_size); |
1522 | if (skey.tag.length() != 0) { |
1523 | MemoryInfo info = GetMemoryInfo(skey.to_string()); |
1524 | if (info.defined()) { |
1525 | align = (info->max_simd_bits + dtype.bits() - 1) / dtype.bits(); |
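        // For example (a sketch): max_simd_bits = 128 with a 32-bit
        // dtype yields align = (128 + 31) / 32 = 4.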
1526 | ICHECK_LE(const_size * dtype.bits(), info->max_num_bits) |
1527 | << "Allocation exceed bound of memory tag " << skey.to_string(); |
1528 | } |
1529 | } |
1530 | |
1531 | e.buffer = Buffer(op->buffer->data, op->buffer->dtype, op->buffer->shape, op->buffer->strides, |
1532 | PrimExpr(), op->buffer->name, align, 0, kDefault, |
1533 | op->buffer->axis_separators, op->buffer->span); |
1534 | e.flattened_buffer = e.buffer.GetFlattenedBuffer(); |
1535 | |
1536 | // TODO(Lunderberg): Move the handling of boolean into a |
1537 | // dedicated pass. |
1538 | |
      // Boolean tensors are backed by an Int8 array.
1540 | if (e.flattened_buffer->dtype == DataType::Bool()) { |
1541 | auto writer = e.flattened_buffer.CopyOnWrite(); |
1542 | writer->dtype = DataType::Int(8); |
1543 | } |
1544 | |
1545 | buffer_var_defines_.insert(op->buffer->data.get()); |
1546 | buf_map_[key] = e; |
1547 | Stmt body = this->VisitStmt(op->body); |
1548 | buffer_var_defines_.erase(op->buffer->data.get()); |
1549 | buf_map_[key].in_scope = false; |
1550 | |
1551 | Stmt ret = |
1552 | Allocate(e.flattened_buffer->data, e.flattened_buffer->dtype, e.flattened_buffer->shape, |
1553 | make_const(DataType::Bool(e.flattened_buffer->dtype.lanes()), true), body); |
1554 | |
1555 | if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { |
1556 | ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound, |
1557 | MakeBound(e.buffer->dtype, e.buffer->shape), ret); |
1558 | } |
1559 | return ret; |
1560 | } |
1561 | } |
1562 | |
1563 | PrimExpr VisitExpr_(const VarNode* op) final { |
1564 | auto it = var_remap_.find(op); |
1565 | if (it != var_remap_.end()) { |
1566 | return it->second; |
1567 | } else { |
1568 | return GetRef<PrimExpr>(op); |
1569 | } |
1570 | } |
1571 | |
1572 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
1573 | PrimExpr expr = StmtExprMutator::VisitExpr_(op); |
1574 | op = expr.as<BufferLoadNode>(); |
1575 | |
1576 | const BufferEntry& e = GetBufferEntry(op->buffer); |
1577 | |
1578 | if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { |
1579 | shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); |
1580 | } |
1581 | |
1582 | auto flattened_indices = e.buffer->ElemOffset(op->indices); |
1583 | PrimExpr val = BufferLoad(e.flattened_buffer, flattened_indices, op->span); |
1584 | |
1585 | if (op->dtype == DataType::Bool()) { |
1586 | ICHECK_EQ(e.flattened_buffer->dtype, DataType::Int(8)) |
1587 | << "Expected int8 backing array for boolean tensor, but received " |
1588 | << e.flattened_buffer->dtype; |
1589 | val = tir::Cast(DataType::Bool(), val); |
1590 | } |
1591 | |
1592 | return val; |
1593 | } |
1594 | |
1595 | Stmt VisitStmt_(const PrefetchNode* op) final { |
1596 | const BufferEntry& e = GetBufferEntry(op->buffer); |
1597 | |
1598 | ICHECK(e.in_scope) << "Cannot prefetch " << op->buffer << ", out of scope." ; |
1599 | ICHECK_EQ(e.buffer->shape.size(), op->bounds.size()) |
1600 | << "Prefetch dim should be the same as buffer dim" ; |
1601 | |
1602 | int block_size = 1, elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(); |
1603 | |
    int starts = static_cast<int>(op->bounds.size()) - 1;
1605 | |
1606 | while (starts > 0) { |
1607 | auto* shape_as_int = e.buffer->shape[starts].as<IntImmNode>(); |
1608 | if (shape_as_int == nullptr || block_size * shape_as_int->value > elem_cnt) break; |
1609 | block_size *= static_cast<int>(shape_as_int->value); |
1610 | starts--; |
1611 | } |
1612 | PrimExpr stride(elem_cnt / block_size); |
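    // For example (a sketch): with float32 elements and a 64-byte cache
    // line, elem_cnt is 16; for a buffer of shape [n, 4], the static
    // innermost dimension folds into block_size = 4, so stride = 4 and
    // each emitted prefetch covers one full cache line.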
1613 | |
1614 | Array<PrimExpr> args; |
1615 | std::vector<Var> vars; |
1616 | |
    for (int i = static_cast<int>(op->bounds.size()) - 1; i > starts; --i) {
1618 | args.push_back(op->bounds[i]->min); |
1619 | } |
1620 | auto& func_name = op->buffer->name; |
1621 | vars.push_back(Var("prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32))); |
1622 | args.push_back(op->bounds[starts]->min + stride * vars.back()); |
1623 | for (int i = starts - 1; i >= 0; --i) { |
1624 | vars.push_back(Var("prefetch." + func_name + "." + std::to_string(i), DataType::Int(32))); |
1625 | args.push_back(vars.back() + op->bounds[i]->min); |
1626 | } |
1627 | |
1628 | Stmt stmt = GetRef<Stmt>(op); |
1629 | for (int i = starts; i >= 0; --i) { |
1630 | if (i < starts) { |
1631 | stmt = For(vars[i], 0, op->bounds[i]->extent, ForKind::kSerial, stmt); |
1632 | } else { |
1633 | PrimExpr load = e.buffer.vload(args, e.buffer->dtype); |
1634 | PrimExpr address = Call(DataType::Handle(), builtin::address_of(), {load}); |
1635 | PrimExpr prefetch = Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1}); |
1636 | stmt = Evaluate(prefetch); |
1637 | PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1; |
1638 | stmt = For(vars[i], 0, extent, ForKind::kSerial, stmt); |
1639 | } |
1640 | } |
1641 | return this->VisitStmt(stmt); |
1642 | } |
1643 | |
1644 | PrimExpr VisitExpr_(const ProducerLoadNode* op) final { |
1645 | LOG(FATAL) << "ProducerLoad cannot appear in a valid TIR PrimFunc. " |
1646 | << "Please run SchedulePostProcToPrimFunc first." ; |
1647 | return PrimExpr(); |
1648 | } |
1649 | |
1650 | Stmt VisitStmt_(const ProducerStoreNode* op) final { |
1651 | LOG(FATAL) << "ProducerStore cannot appear in a valid TIR PrimFunc. " |
1652 | << "Please run SchedulePostProcToPrimFunc first." ; |
1653 | return Stmt(); |
1654 | } |
1655 | |
1656 | Stmt VisitStmt_(const ProducerRealizeNode* op) final { |
1657 | LOG(FATAL) << "ProducerRealize cannot appear in a valid TIR PrimFunc. " |
1658 | << "Please run SchedulePostProcToPrimFunc first." ; |
1659 | return Stmt(); |
1660 | } |
1661 | |
1662 | private: |
1663 | // Helper function for visiting Allocate and AllocateConst. If, in |
1664 | // the future, these are updated to hold a buffer (Buffer) object |
1665 | // rather than a buffer_var (Var), this function can be replaced |
1666 | // with a call to GetBufferEntry. |
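  //
  // For example (a sketch): extents [n, m] flatten to [n*m] when there
  // are no axis separators, and stay [n, m] with axis_separators = [1].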
1667 | template <typename Node> |
1668 | Array<PrimExpr> FlattenExtents(const Node& node) { |
1669 | arith::Analyzer analyzer; |
1670 | |
    // Returns true if the allocation's extents match the buffer's shape.
1672 | auto is_compatible_buffer = [&](const Buffer& buffer) { |
1673 | if (buffer->shape.size() != node->extents.size()) { |
1674 | return false; |
1675 | } |
1676 | for (size_t i = 0; i < buffer->shape.size(); i++) { |
1677 | if (!analyzer.CanProveEqual(buffer->shape[i], node->extents[i])) { |
1678 | return false; |
1679 | } |
1680 | } |
1681 | |
1682 | return true; |
1683 | }; |
1684 | |
1685 | auto int_array_equal = [](const Array<IntImm>& a, const Array<IntImm>& b) { |
1686 | if (a.size() != b.size()) { |
1687 | return false; |
1688 | } |
1689 | |
1690 | for (size_t i = 0; i < a.size(); i++) { |
1691 | if (a[i]->value != b[i]->value) { |
1692 | return false; |
1693 | } |
1694 | } |
1695 | |
1696 | return true; |
1697 | }; |
1698 | |
1699 | Array<IntImm> axis_separators; |
1700 | auto it = buffer_var_map_.find(node->buffer_var.get()); |
1701 | if (it != buffer_var_map_.end()) { |
1702 | const auto& buffers = it->second; |
1703 | if (buffers.size() == 0) { |
1704 | // No buffers use this allocation, treat as flat and optimize |
1705 | // out later. |
1706 | } else if (buffers.size() == 1) { |
1707 | // Only one buffer uses this allocation, so use its axis |
1708 | // separators. |
1709 | axis_separators = buffers[0]->axis_separators; |
1710 | } else { |
1711 | // Try to find a buffer using this allocation with a matching |
1712 | // shape. |
1713 | Buffer compatible_buffer; |
1714 | for (const auto& buffer : buffers) { |
1715 | if (is_compatible_buffer(buffer)) { |
1716 | ICHECK(!compatible_buffer.defined() || |
1717 | int_array_equal(compatible_buffer->axis_separators, buffer->axis_separators)) |
1718 | << "Cannot determine axis separators to use when flattening " |
1719 | << node->buffer_var->name_hint |
1720 | << ", multiple buffer objects found with conflicting axis separators" ; |
1721 | compatible_buffer = buffer; |
1722 | } |
1723 | } |
1724 | ICHECK(compatible_buffer.defined()) |
1725 | << "Cannot determine axis separators to use when flattening " |
          << node->buffer_var->name_hint << ", no buffers found with matching shape";
1727 | axis_separators = compatible_buffer->axis_separators; |
1728 | } |
1729 | } |
1730 | |
    // Use GetFlattenedBuffer to determine the flattened shape of the
    // output. Only the shape and axis separators need to be defined;
    // everything else can be a dummy value.
1734 | Buffer dummy_buffer = |
        decl_buffer(node->extents, DataType::Float(32), "buffer", "", axis_separators);
1736 | return dummy_buffer.GetFlattenedBuffer()->shape; |
1737 | } |
1738 | |
  // Alignment information for one buffer dimension
1740 | struct DimAlignInfo { |
1741 | int align_factor{0}; |
1742 | int align_offset{0}; |
1743 | }; |
1744 | // The buffer entry in the flatten map |
1745 | struct BufferEntry { |
1746 | // The buffer object |
1747 | Buffer buffer; |
1748 | // The updated buffer object, after flattening has been applied. |
1749 | Buffer flattened_buffer; |
1750 | // Whether the buffer is external |
1751 | bool external{false}; |
1752 | // Whether the buffer is currently in scope. |
1753 | bool in_scope{true}; |
1754 | }; |
1755 | |
1756 | bool ShapeIsValid(const Array<PrimExpr>& shape) { |
    // A zero-dimensional tensor does not need a bounds check.
1758 | if (!shape.size()) return false; |
1759 | |
1760 | for (size_t i = 0; i < shape.size(); ++i) { |
1761 | if (!shape[i].defined() || !shape[i].dtype().is_scalar() || is_negative_const(shape[i])) { |
1762 | return false; |
1763 | } |
1764 | } |
1765 | return true; |
1766 | } |
1767 | |
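  // Build the value for the buffer_bound attribute as a tvm_tuple call.
  // For example (a sketch): with a scalar dtype (lanes == 1) and shape
  // [n, m], the bound evaluates to n*m.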
1768 | PrimExpr MakeBound(const DataType& type, const Array<PrimExpr>& shape) { |
    // We have already checked that the shape size is greater than 0.
1770 | PrimExpr bound = Mul(make_const(shape[0].dtype(), type.lanes()), shape[0]); |
1771 | for (size_t i = 1; i < shape.size(); ++i) { |
1772 | bound = Mul(bound, Mul(make_const(bound.dtype(), type.lanes()), shape[i])); |
1773 | } |
1774 | Array<PrimExpr> bounds{bound}; |
1775 | |
1776 | return Call(DataType::Handle(), builtin::tvm_tuple(), bounds); |
1777 | } |
1778 | |
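  // Look up the BufferEntry for a buffer. Entries are created by the
  // constructor for external buffers and by VisitStmt_(BufferRealizeNode)
  // for internal buffers; a buffer backed by a plain Allocate/Let (e.g.
  // from tvm.tir.ir_builder) gets an entry lazily on first access.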
1779 | const BufferEntry& GetBufferEntry(Buffer buffer) { |
1780 | auto alloc_key = buffer->data.get(); |
1781 | if (!buf_map_.count(buffer) && buffer_var_defines_.count(alloc_key)) { |
1782 | BufferEntry entry; |
1783 | entry.buffer = buffer; |
1784 | entry.flattened_buffer = buffer.GetFlattenedBuffer(); |
      // Boolean tensors are backed by an Int8 array.
1786 | if (entry.flattened_buffer->dtype == DataType::Bool()) { |
1787 | auto writer = entry.flattened_buffer.CopyOnWrite(); |
1788 | writer->dtype = DataType::Int(8); |
1789 | } |
1790 | buf_map_[buffer] = std::move(entry); |
1791 | } |
1792 | |
1793 | auto it = buf_map_.find(buffer); |
1794 | ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer; |
1795 | const BufferEntry& e = it->second; |
1796 | ICHECK(e.in_scope) << "Cannot access a buffer " << buffer->name << ", out of scope" ; |
1797 | return it->second; |
1798 | } |
1799 | |
  // Variable remap
1802 | std::unordered_map<const VarNode*, PrimExpr> var_remap_; |
1803 | // Set of vars that have occurred in an AllocateNode, but haven't |
1804 | // yet occurred in a BufferLoad/BufferStore. |
1805 | std::unordered_set<const VarNode*> buffer_var_defines_; |
  // Map from an allocation variable to the buffer(s) that it backs.
  // Used to determine the axis_separators that should be used when
  // flattening the extents of an AllocateNode.
1809 | std::unordered_map<const VarNode*, std::vector<Buffer>> buffer_var_map_; |
1810 | // Buffer map |
1811 | std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_; |
  // Collects the shapes of buffers accessed in the statement being visited.
1813 | std::vector<std::pair<Var, Array<PrimExpr>>> shape_collector_; |
  // Bounds populator; only the arith::Analyzer it carries is needed here.
1816 | IRVisitorWithAnalyzer* bound_analyzer_; |
  // The size of a cache line
  int cache_line_size_;
  // Whether to annotate loads/stores with their bounds.
1820 | bool create_bound_attributes_{false}; |
1821 | }; |
1822 | |
1823 | /*! |
1824 | * \brief Simplify assert statements. |
1825 | * |
1826 | * If an assert statement can be statically verified to be true, |
1827 | * remove the assert statement. Otherwise, keep the assert statement |
1828 | * unmodified. |
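 *
 * For example (a sketch): an assert whose condition the bound analyzer
 * can prove, such as `assert(i < n)` inside a loop over [0, n), is
 * replaced by its body.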
1829 | */ |
1830 | class AssertSimplifier : public StmtMutator { |
1831 | public: |
1832 | static transform::Pass Pass() { |
1833 | auto pass_func = [=](PrimFunc func, IRModule m, transform::PassContext ctx) { |
1834 | IRVisitorWithAnalyzer bound_analyzer; |
1835 | |
1836 | bound_analyzer(func->body); |
1837 | |
1838 | auto fptr = func.CopyOnWrite(); |
1839 | fptr->body = AssertSimplifier(&bound_analyzer)(std::move(fptr->body)); |
1840 | return func; |
1841 | }; |
    return transform::CreatePrimFuncPass(pass_func, 0, "tir.AssertSimplifier", {});
1843 | } |
1844 | |
1845 | explicit AssertSimplifier(IRVisitorWithAnalyzer* bound_analyzer) |
1846 | : bound_analyzer_(bound_analyzer) {} |
1847 | |
1848 | Stmt VisitStmt_(const AssertStmtNode* op) final { |
1849 | Stmt stmt = StmtMutator::VisitStmt_(op); |
1850 | op = stmt.as<AssertStmtNode>(); |
1851 | |
1852 | PrimExpr condition = bound_analyzer_->Simplify(op->condition); |
1853 | if (is_one(condition)) { |
1854 | return op->body; |
1855 | } |
1856 | |
1857 | return stmt; |
1858 | } |
1859 | |
1860 | private: |
1861 | IRVisitorWithAnalyzer* bound_analyzer_; |
1862 | }; |
1863 | |
// The specific tensor data layout is not determined before the
// StorageFlatten pass. We use buffer_bind_scope to specify
// beforehand that we want to bind a subregion of a tensor to a
// symbolic buffer, which gets used in an extern function.
//
// Example:
//
// realize A in range [i*4, extent=10) {
//   bind Ab to A in [i*4+1, extent=4) {
//     call_func(Ab.ptr, Ab.shape[0])
//   }
// }
//
// After StorageFlatten:
//
// alloc A[10]
// call(A + 1, 4)
//
// Buffer is a protocol for declaring the specific data layout and
// shape that we expect, so this pass needs to check:
// - whether the bind range is within the realize range;
// - whether we can match the requirements of the buffer;
// - how to remap variables such as Ab.ptr to their actual values.
//
// Here are a few possible failure cases:
// - The buffer is declared to have a constant shape, but we try to
//   bind it to a different one.
// - The buffer is declared to be compact (no strides), but the bound
//   region is a subregion of a matrix (tensor), which requires strides.
//
// We do support a few relaxed cases, such as binding a region with
// shape [1, 1, n, m] to a buffer with shape [n, m].
1898 | PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) { |
1899 | // Only apply this pass to TIR from TE schedules. Because this is a |
1900 | // per-function attribute, we can't just check it once for the |
1901 | // entire module and apply the Sequential transform. |
  Optional<Bool> from_legacy_te_schedule = func->GetAttr("from_legacy_te_schedule", Bool(false));
1903 | if (from_legacy_te_schedule.value()) { |
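    // The ordering below matters: StorageFlattener's ICHECKs assume
    // that the preceding legalization passes have already run.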
1904 | auto seq = transform::Sequential( |
1905 | { |
1906 | BufferShapeLegalize::Pass(), |
1907 | BufferStrideLegalize::Pass(), |
1908 | ThreadScopePropagate::Pass(), |
1909 | BufferBindUnwrapper::Pass(), |
1910 | ApplyLayoutTransforms::Pass(), |
1911 | StorageFlattener::Pass(cache_line_size, create_bound_attributes), |
1912 | AssertSimplifier::Pass(), |
1913 | }, |
1914 | "tir.StorageFlatten_impl" ); |
1915 | GlobalVar dummy_func_name("dummy_func" ); |
1916 | IRModule mod(Map<GlobalVar, BaseFunc>({{dummy_func_name, func}})); |
1917 | mod = seq(mod); |
1918 | return Downcast<PrimFunc>(mod->Lookup(dummy_func_name)); |
1919 | } else { |
1920 | return func; |
1921 | } |
1922 | } |
1923 | |
1924 | namespace transform { |
1925 | |
1926 | TVM_REGISTER_GLOBAL("tir.transform.ApplyLayoutTransforms" ) |
1927 | .set_body_typed(ApplyLayoutTransforms::Pass); |
1928 | |
1929 | // TODO(tvm-team): consolidate configs to the PassContext |
1930 | Pass StorageFlatten(int cache_line_size, bool create_bound_attributes) { |
1931 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
1932 | return StorageFlatten(std::move(f), cache_line_size, create_bound_attributes); |
1933 | }; |
  return CreatePrimFuncPass(pass_func, 0, "tir.StorageFlatten", {});
1935 | } |
1936 | |
1937 | TVM_REGISTER_GLOBAL("tir.transform.StorageFlatten" ).set_body_typed(StorageFlatten); |
1938 | |
1939 | } // namespace transform |
1940 | |
1941 | } // namespace tir |
1942 | } // namespace tvm |
1943 | |