1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file storage_flatten.cc
22 * \brief Flattens storage from multi-dimensional array to 1D buffer access
23 */
// The pass definition originates from the Halide pipeline.
25
26#include <tvm/arith/analyzer.h>
27#include <tvm/runtime/device_api.h>
28#include <tvm/runtime/registry.h>
29#include <tvm/target/target_info.h>
30#include <tvm/te/operation.h>
31#include <tvm/tir/buffer.h>
32#include <tvm/tir/builtin.h>
33#include <tvm/tir/expr.h>
34#include <tvm/tir/op.h>
35#include <tvm/tir/stmt.h>
36#include <tvm/tir/stmt_functor.h>
37#include <tvm/tir/transform.h>
38
39#include <unordered_map>
40#include <unordered_set>
41
42#include "../../arith/ir_visitor_with_analyzer.h"
43#include "../../runtime/thread_storage_scope.h"
44#include "arg_binder.h"
45#include "ir_utils.h"
46
47namespace tvm {
48namespace tir {
49
50using arith::IRVisitorWithAnalyzer;
51using runtime::StorageRank;
52using runtime::StorageScope;
53using runtime::ThreadScope;
54
55/* Make buffer realize extents and buffer shapes consistent
56 *
57 * For external buffers, verify that the extents of BufferRealize
58 * nodes match the shape of the external buffer. For internal
59 * buffers, rewrite the shape of the Buffer objects to match the
60 * extent of the BufferRealize, and rewrite indices of
61 * BufferLoad/BufferStore nodes to match.
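 *
 * As a rough illustration (not taken from an actual lowering), a
 * buffer B of shape [100] that is realized only over the range
 * [10, 50) would be rewritten to a buffer of shape [40] realized
 * over [0, 40), with an access B[i] rewritten to B[i - 10].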
62 */
63class BufferShapeLegalize : public StmtExprMutator {
64 public:
65 static transform::Pass Pass() {
66 auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) {
67 IRVisitorWithAnalyzer bound_analyzer;
68
69 bound_analyzer(func->body);
70
71 auto pass = BufferShapeLegalize(func->buffer_map, &bound_analyzer);
72
73 auto fptr = func.CopyOnWrite();
74 fptr->body = pass(std::move(fptr->body));
75 if (auto map = func->attrs.GetAttr<Map<Buffer, Array<IndexMap>>>("layout_transform_map")) {
76 func = WithAttr(std::move(func), "layout_transform_map", pass.UpdateIndexMap(map.value()));
77 }
78 return func;
79 };
80 return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferShapeLegalize", {});
81 }
82
83 explicit BufferShapeLegalize(const Map<Var, Buffer>& extern_buffer_map,
84 IRVisitorWithAnalyzer* bound_analyzer)
85 : bound_analyzer_(bound_analyzer) {
86 for (auto kv : extern_buffer_map) {
87 Buffer buf = kv.second;
88 extern_buffers_.insert(buf);
89
90 BufferEntry remap;
91 remap.remap_to = buf;
92 remap.index_offsets = Array<PrimExpr>(buf->shape.size(), 0);
93 remap.in_scope = true;
94 buf_map_[buf] = remap;
95 }
96 }
97
98 Map<Buffer, Array<IndexMap>> UpdateIndexMap(const Map<Buffer, Array<IndexMap>>& orig) {
99 Map<Buffer, Array<IndexMap>> output;
100 for (const auto& kv : orig) {
101 auto it = buf_map_.find(kv.first);
102 if (it != buf_map_.end()) {
103 output.Set(it->second.remap_to, kv.second);
104 } else {
105 output.Set(kv.first, kv.second);
106 }
107 }
108 return output;
109 }
110
111 PrimExpr VisitExpr_(const VarNode* op) final {
112 auto it = var_remap_.find(op);
113 if (it != var_remap_.end()) {
114 return it->second;
115 } else {
116 return GetRef<PrimExpr>(op);
117 }
118 }
119
120 Stmt VisitStmt_(const BufferRealizeNode* op) final {
121 // BufferRealizeNode for an external buffer serves as an
122 // annotation of the external buffers, and should not be changed.
123 // Instead, verify that the bounds match the external
124 // buffer.
125 if (extern_buffers_.count(op->buffer)) {
126 CHECK_EQ(op->buffer->shape.size(), op->bounds.size())
127 << "External buffer realize has mismatched dimension";
128 Stmt stmt = StmtExprMutator::VisitStmt_(op);
129 op = stmt.as<BufferRealizeNode>();
130 ICHECK(op);
131
132 for (size_t i = 0; i < op->bounds.size(); i++) {
133 PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent);
134 std::ostringstream ss;
135 ss << "Dim " << i << " of external buffer " << op->buffer->name << " has shape "
136 << op->buffer->shape[i] << ", but is only realized for extent " << op->bounds[i]->extent;
137 if (auto eq_int = eq.as<IntImmNode>()) {
138 ICHECK(eq_int->value) << ss.str();
139 } else {
140 stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt);
141 }
142 }
143 return stmt;
144 }
145
146 // Compute the new buffer shape, new realization bounds, and the
147 // offsets to be applied to buffer access.
148 Array<PrimExpr> realized_shape;
149 Array<PrimExpr> index_offsets;
150 Array<Range> new_bounds;
151 for (size_t i = 0; i < op->bounds.size(); i++) {
152 const Range& bound = op->bounds[i];
153 realized_shape.push_back(bound->extent);
154 index_offsets.push_back(bound->min);
155 new_bounds.push_back({0, bound->extent});
156 }
157
158 if (op->buffer->shape.size()) {
159 ICHECK_EQ(op->buffer->shape.size(), realized_shape.size())
160 << "Inconsistency between dimension of buffer " << op->buffer
161 << " and dimension of its realized bounds.";
162 }
163
164 Buffer key = op->buffer;
165
166 Buffer buf = op->buffer;
167 auto write_ptr = buf.CopyOnWrite();
168 write_ptr->shape = realized_shape;
169
170 {
171 BufferEntry remap;
172 remap.remap_to = buf;
173 remap.index_offsets = index_offsets;
174 remap.in_scope = true;
175 buf_map_[key] = remap;
176 }
177
178 Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span);
179
180 buf_map_.at(key).in_scope = false;
181
182 return stmt;
183 }
184
185 Stmt VisitStmt_(const BufferStoreNode* op) final {
186 auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
187 return VisitBufferAccess(std::move(node));
188 }
189
190 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
191 auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
192 return VisitBufferAccess(std::move(node));
193 }
194
195 template <typename Node>
196 Node VisitBufferAccess(Node node) {
197 auto it = buf_map_.find(node->buffer);
198 if (it != buf_map_.end()) {
199 const BufferEntry& entry = it->second;
200 ICHECK(entry.in_scope) << "Cannot access an out-of-scope buffer";
201
202 Array<PrimExpr> indices = node->indices;
203 if (entry.index_offsets.size()) {
204 ICHECK_GE(entry.index_offsets.size(), indices.size())
205 << "Cannot bind buffer to a shape of lower dimension.";
206
207 Array<PrimExpr> new_indices;
208
209 // Pad leading indices with zero, matching the "fuzzy_match"
210 // behavior from ArgBinder::BindBuffer.
211 size_t diff = entry.index_offsets.size() - indices.size();
212 for (size_t i = 0; i < diff; i++) {
213 new_indices.push_back(0);
214 }
215
216 // Offset indices used to access buffers of a reduced size.
217 for (size_t i = 0; i < indices.size(); i++) {
218 PrimExpr offset = entry.index_offsets[i + diff];
219 new_indices.push_back(indices[i] - offset);
220 }
221 indices = new_indices;
222 }
223
224 auto write_ptr = node.CopyOnWrite();
225 write_ptr->indices = indices;
226 write_ptr->buffer = entry.remap_to;
227 }
228 return node;
229 }
230
231 Stmt VisitStmt_(const AttrStmtNode* op) final {
232 if (op->node->IsInstance<tir::BufferNode>()) {
233 // Visit body before checking internal_buf_map_, because we
234 // don't know if the BufferNode needs to be changed until we
235 // look in the body for a BufferRealizeNode with different
236 // extents.
237 Stmt body = this->VisitStmt(op->body);
238
239 Buffer buffer = Downcast<tir::Buffer>(op->node);
240 auto it = buf_map_.find(buffer);
241 if (it != buf_map_.end()) {
242 buffer = it->second.remap_to;
243 return AttrStmt(it->second.remap_to, op->attr_key, op->value, body);
244 }
245 return AttrStmt(buffer, op->attr_key, op->value, body);
246
247 } else if (op->attr_key == attr::buffer_bind_scope) {
248 return HandleBufferBindScope(op);
249 }
250
251 return StmtExprMutator::VisitStmt_(op);
252 }
253
254 private:
255 // Any buffers that give views into a resized buffer should be
256 // updated, both to refer to the resized buffer and to have the view
257 // window updated. For example, suppose B1 is a 1-D buffer of size
258 // 100 which is only realized on the range (10,50), and buffer V1 is
259 // a view into B1[25:35]. When B1 is replaced with B2, a buffer of
260 // size 40 realized on the range (0,40), V1 must be replaced to be a
261 // view into B2[15:25].
262 Stmt HandleBufferBindScope(const AttrStmtNode* op) {
263 Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
264 ICHECK_EQ(arr.size(), 2U);
265 Buffer buffer = Downcast<Buffer>(arr[0]);
266 ICHECK(buffer.defined());
267 Buffer target = Downcast<Buffer>(arr[1]);
268 ICHECK(target.defined());
269
270 auto it = buf_map_.find(target);
271 ICHECK(it != buf_map_.end()) << "attr::buffer_bind_scope target " << target << " not in scope.";
272 const BufferEntry& target_remap = it->second;
273
274 ICHECK(target_remap.in_scope) << "Cannot bind " << buffer->name
275 << " to the out-of-scope buffer " << target_remap.remap_to->name;
276
277 Call tuple = Downcast<Call>(op->value);
278 ICHECK(tuple.defined() && tuple->op.same_as(builtin::tvm_tuple()));
279
280 Array<PrimExpr> new_tuple_args;
281 Array<PrimExpr> realized_begins;
282 Array<PrimExpr> view_shape;
283 ICHECK_EQ(tuple->args.size(), target_remap.index_offsets.size() * 2)
284 << "attr::buffer_bind_scope to define " << buffer << " as a view into " << target
        << " does not match the dimensionality of " << target;
286 for (size_t i = 0; i < target_remap.index_offsets.size(); i++) {
287 PrimExpr parent_begin = tuple->args[2 * i];
288 PrimExpr view_extent = tuple->args[2 * i + 1];
289 // Offset the begin of the buffer view by the offset of the target buffer.
290 new_tuple_args.push_back(parent_begin - target_remap.index_offsets[i]);
291 // Keep the extent of the buffer view the same.
292 new_tuple_args.push_back(view_extent);
293 // Use the extent of the buffer view to define the buffer view's shape.
294 view_shape.push_back(view_extent);
295 // Within the buffer view, indices start at 0.
296 realized_begins.push_back(0);
297 }
298
    // If the view binds to a target buffer of higher dimensionality,
    // its leading dimensions are padded out with extent 1.
302 ICHECK_GE(view_shape.size(), buffer->shape.size())
303 << "Cannot bind " << buffer << " to a shape of lower dimension.";
304 if (view_shape.size() > buffer->shape.size()) {
305 size_t diff = view_shape.size() - buffer->shape.size();
306 Array<PrimExpr> padded_shape;
307 for (size_t i = 0; i < diff; i++) {
308 padded_shape.push_back(1);
309 }
310 for (auto dim : buffer->shape) {
311 padded_shape.push_back(dim);
312 }
313 view_shape = std::move(padded_shape);
314 }
315
316 // If a buffer has strides defined, and is being remapped into a
317 // shape with additional dimensions, then define dummy values for
318 // the strides.
319 Array<PrimExpr> realized_strides = buffer->strides;
320 if ((realized_strides.size() > 0) && (realized_strides.size() != view_shape.size())) {
321 ICHECK_GE(view_shape.size(), realized_strides.size())
322 << "Cannot bind the strides of " << buffer << " to a shape of lower dimension";
323 size_t diff = view_shape.size() - buffer->strides.size();
324
325 Array<PrimExpr> updated_strides;
326 for (size_t i = 0; i < diff; i++) {
327 updated_strides.push_back(Var("stride", buffer->shape[0].dtype()));
328 }
329 for (auto stride : buffer->strides) {
330 updated_strides.push_back(stride);
331 }
332 realized_strides = updated_strides;
333 }
334
335 Buffer key = buffer;
336
337 auto write_ptr = buffer.CopyOnWrite();
338 write_ptr->shape = view_shape;
339 write_ptr->strides = realized_strides;
340
341 {
342 BufferEntry remap;
343 remap.index_offsets = realized_begins;
344 remap.remap_to = buffer;
345 remap.in_scope = true;
346 buf_map_[key] = remap;
347 }
348
    // Define remappings of any Variables referencing Buffer internals
    // (e.g. Store/Load nodes).  Passing fuzzy_match=true allows the
    // remapped buffer to have a different number of dimensions.
352 ArgBinder binder(&var_remap_);
353 binder.BindBuffer(key, buffer, key->name, true);
354
355 Stmt body = this->VisitStmt(op->body);
356 body = MergeNest(binder.asserts(), body);
357 body = MergeNest(binder.init_nest(), body);
358
359 Stmt stmt = AttrStmt(Array<ObjectRef>{buffer, target_remap.remap_to}, op->attr_key,
360 Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span), body);
361
362 for (const Var& v : binder.defs()) {
363 var_remap_.erase(v.get());
364 }
365
366 buf_map_.at(key).in_scope = false;
367 return stmt;
368 }
369
370 std::unordered_map<const VarNode*, PrimExpr> var_remap_;
371
372 std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> extern_buffers_;
373
374 struct BufferEntry {
375 Buffer remap_to;
376 Array<PrimExpr> index_offsets;
377 bool in_scope;
378 };
379
380 std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
381
382 IRVisitorWithAnalyzer* bound_analyzer_;
383};
384
385/* Apply dimension alignment restrictions
386 *
387 * Buffers annotated with attr::buffer_dim_align may need to have
388 * strides defined such that they are no longer in a compact shape.
389 * After this pass, buffers have stride definitions to include these
390 * alignment restrictions, and attr::buffer_dim_align annotations have
391 * been removed.
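 *
 * A rough sketch (values chosen for illustration only): an annotation
 * attr [A] "buffer_dim_align" = tvm_tuple(0, 16, 0) on a buffer A of
 * shape [4, 30] is consumed, and A is given explicit strides [32, 1],
 * padding each row out to a multiple of 16 elements.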
392 */
393class BufferStrideLegalize : public StmtExprMutator {
394 public:
395 static transform::Pass Pass() {
396 auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) {
397 IRVisitorWithAnalyzer bound_analyzer;
398
399 bound_analyzer(func->body);
400
401 auto pass = BufferStrideLegalize(func->buffer_map, &bound_analyzer);
402
403 auto fptr = func.CopyOnWrite();
404 fptr->body = pass(std::move(fptr->body));
405 fptr->buffer_map = pass.UpdatedExternBufferMap();
406 if (auto map = func->attrs.GetAttr<Map<Buffer, Array<IndexMap>>>("layout_transform_map")) {
407 func = WithAttr(std::move(func), "layout_transform_map", pass.UpdateIndexMap(map.value()));
408 }
409 return func;
410 };
411 return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferStrideLegalize", {});
412 }
413
414 explicit BufferStrideLegalize(const Map<Var, Buffer>& extern_buffer_map,
415 IRVisitorWithAnalyzer* bound_analyzer)
416 : bound_analyzer_(bound_analyzer) {
417 for (auto kv : extern_buffer_map) {
418 Buffer buf = kv.second;
419 Buffer with_strides = WithStrides(buf);
420 {
421 BufferEntry entry;
422 entry.remap_to = with_strides;
423 entry.in_scope = true;
424 buf_map_[buf] = entry;
425 }
426 updated_extern_buffer_map_.Set(kv.first, with_strides);
427 }
428 }
429
430 Map<Buffer, Array<IndexMap>> UpdateIndexMap(const Map<Buffer, Array<IndexMap>>& orig) {
431 Map<Buffer, Array<IndexMap>> output;
432 for (const auto& kv : orig) {
433 auto it = buf_map_.find(kv.first);
434 if (it != buf_map_.end()) {
435 output.Set(it->second.remap_to, kv.second);
436 } else {
437 output.Set(kv.first, kv.second);
438 }
439 }
440 return output;
441 }
442
443 Map<Var, Buffer> UpdatedExternBufferMap() const { return updated_extern_buffer_map_; }
444
445 Buffer WithStrides(Buffer buf) {
446 auto cache_key = buf;
447
448 auto it = buf_map_.find(cache_key);
449 if (it != buf_map_.end()) {
450 const BufferEntry& entry = it->second;
451 ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer";
452 return entry.remap_to;
453 }
454
455 Array<PrimExpr> shape = buf->shape;
456
457 if (buf->strides.size()) {
458 ICHECK_EQ(buf->strides.size(), buf->shape.size())
459 << "Buffer " << buf << " has inconsistent strides/shape.";
460 } else if (dim_align_.count(buf) == 0) {
      // Kept to match the behavior of the previous version.
462 // There are many parts of the codebase that assume that a
463 // strided array cannot be compact. For example,
464 // ArgBinder::BindBuffer and tir.Specialize. To avoid breaking
465 // these, do not define the strides unless required for a
466 // non-compact array.
467 } else if (shape.size() == 0) {
468 // Can't define the strides for a buffer without a known shape.
469 } else {
470 // With everything checked, can now define the updated strides
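      // For example (hypothetical values): with a running stride of 30,
      // align_factor=16, and align_offset=0, the adjustment below gives
      // 30 + ((16 + 0 - (30 % 16)) % 16) = 32, i.e. the stride is
      // rounded up to the next aligned value.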
471 std::vector<PrimExpr> rstrides;
472 const std::vector<DimAlignInfo>& avec = dim_align_[buf];
473 int first_dim = 0;
474 PrimExpr stride = make_const(shape[first_dim].dtype(), 1);
475 for (size_t i = shape.size(); i != 0; --i) {
476 size_t dim = i - 1;
477 if (dim < avec.size() && avec[dim].align_factor != 0) {
478 PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
479 PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
480 stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
481 stride = bound_analyzer_->Simplify(stride);
482 }
483 rstrides.push_back(stride);
484 stride = stride * shape[dim];
485 }
486
487 buf.CopyOnWrite()->strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
488 }
489
490 BufferEntry entry;
491 entry.remap_to = buf;
492 entry.in_scope = true;
493 buf_map_[cache_key] = entry;
494
495 return buf;
496 }
497
498 Stmt VisitStmt_(const AttrStmtNode* op) final {
499 if (op->attr_key == attr::buffer_dim_align) {
500 auto buffer = Downcast<tir::Buffer>(op->node);
501 const CallNode* tuple = op->value.as<CallNode>();
502 ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
503 auto& vinfo = dim_align_[buffer];
504 int dim = tuple->args[0].as<IntImmNode>()->value;
505 if (static_cast<size_t>(dim) >= vinfo.size()) {
506 vinfo.resize(dim + 1);
507 }
508 vinfo[dim].align_factor = tuple->args[1].as<IntImmNode>()->value;
509 vinfo[dim].align_offset = tuple->args[2].as<IntImmNode>()->value;
510
511 return this->VisitStmt(op->body);
512 } else if (op->attr_key == attr::buffer_bind_scope) {
513 Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
514 ICHECK_EQ(arr.size(), 2U);
515 Buffer source = Downcast<Buffer>(arr[0]);
516 Buffer target_with_strides = WithStrides(Downcast<Buffer>(arr[1]));
517 Buffer source_with_strides = WithStrides(source);
518
519 Stmt body = this->VisitStmt(op->body);
520
521 buf_map_[source].in_scope = false;
522
523 return AttrStmt(Array<ObjectRef>{source_with_strides, target_with_strides}, op->attr_key,
524 op->value, body, op->span);
525 } else {
526 return StmtExprMutator::VisitStmt_(op);
527 }
528 }
529
530 // AllocateNodes may be present from tvm.tir.ir_builder. This can
531 // be simplified in the future by having AllocateNode hold a buffer,
532 // rather than a buffer_var.
533 Stmt VisitStmt_(const AllocateNode* op) final {
534 buffer_var_defines_.insert(op->buffer_var.get());
535 return StmtExprMutator::VisitStmt_(op);
536 }
537
538 Stmt VisitStmt_(const AllocateConstNode* op) final {
539 buffer_var_defines_.insert(op->buffer_var.get());
540 return StmtExprMutator::VisitStmt_(op);
541 }
542
543 Stmt VisitStmt_(const LetStmtNode* op) final {
544 if (op->var.dtype().is_handle()) {
545 buffer_var_defines_.insert(op->var.get());
546 }
547 return StmtExprMutator::VisitStmt_(op);
548 }
549
550 PrimExpr VisitExpr_(const LetNode* op) final {
551 if (op->var.dtype().is_handle()) {
552 buffer_var_defines_.insert(op->var.get());
553 }
554 return StmtExprMutator::VisitExpr_(op);
555 }
556
557 Stmt VisitStmt_(const BufferRealizeNode* op) final {
558 Buffer key = op->buffer;
559 Buffer with_strides = WithStrides(op->buffer);
560
561 Stmt stmt = StmtExprMutator::VisitStmt_(op);
562
563 buf_map_[key].in_scope = false;
564 op = stmt.as<BufferRealizeNode>();
565 ICHECK(op);
566
567 return BufferRealize(with_strides, op->bounds, op->condition, op->body, op->span);
568 }
569
570 Stmt VisitStmt_(const BufferStoreNode* op) final {
571 auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
572 return VisitBufferAccess(std::move(node));
573 }
574
575 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
576 auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
577 return VisitBufferAccess(std::move(node));
578 }
579
580 template <typename Node>
581 Node VisitBufferAccess(Node node) {
582 auto it = buf_map_.find(node->buffer);
583 ICHECK(it == buf_map_.end() || it->second.in_scope)
584 << "Cannot access a buffer " << node->buffer->name << ", out of scope";
585
586 auto with_strides = WithStrides(node->buffer);
587 if (!with_strides.same_as(node->buffer)) {
588 node.CopyOnWrite()->buffer = with_strides;
589 }
590
591 return node;
592 }
593
594 private:
595 Map<Var, Buffer> updated_extern_buffer_map_;
596
597 struct DimAlignInfo {
598 int align_factor{0};
599 int align_offset{0};
600 };
601
602 // Dimension alignment
603 std::unordered_map<Buffer, std::vector<DimAlignInfo>, ObjectPtrHash, ObjectPtrEqual> dim_align_;
604
605 struct BufferEntry {
606 Buffer remap_to;
607 bool in_scope;
608 };
609
610 std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
611
612 // Set of vars that have occurred in an AllocateNode, but haven't
613 // yet occurred in a BufferLoad/BufferStore.
614 std::unordered_set<const VarNode*> buffer_var_defines_;
615
616 IRVisitorWithAnalyzer* bound_analyzer_;
617};
618
619/* Use the scope of IterVar to determine storage scope.
620 *
621 * For buffers that do not have an explicit storage scope defined, a
622 * reasonable storage scope may be defined based on the thread scope
623 * that contains the buffer's allocation. All other buffers without a
624 * scope are assigned to global scope.
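 *
 * For example (illustrative only), a scope-less buffer realized
 * inside an attr::thread_extent annotation for "threadIdx.x" would
 * have its data Var rewritten with a "local" pointer scope, while the
 * same buffer outside of any thread scope would be assigned "global".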
625 */
626class ThreadScopePropagate : public StmtExprMutator {
627 public:
628 static transform::Pass Pass() {
629 auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) {
630 auto pass = ThreadScopePropagate(func->buffer_map);
631
632 auto fptr = func.CopyOnWrite();
633 fptr->body = pass(std::move(fptr->body));
634 if (auto map = func->attrs.GetAttr<Map<Buffer, Array<IndexMap>>>("layout_transform_map")) {
635 func = WithAttr(std::move(func), "layout_transform_map", pass.UpdateIndexMap(map.value()));
636 }
637 return func;
638 };
639 return transform::CreatePrimFuncPass(pass_func, 0, "tir.ThreadScopePropagate", {});
640 }
641
642 explicit ThreadScopePropagate(const Map<Var, Buffer>& extern_buffer_map) {
643 // External buffers shouldn't be overwritten, even if they have a
644 // BufferRealizeNode.
645 for (auto kv : extern_buffer_map) {
646 external_buffers_.insert(kv.second);
647 }
648 }
649
650 Map<Buffer, Array<IndexMap>> UpdateIndexMap(const Map<Buffer, Array<IndexMap>>& orig) {
651 Map<Buffer, Array<IndexMap>> output;
652 for (const auto& kv : orig) {
653 auto it = buf_remap_.find(kv.first->data);
654 if (it != buf_remap_.end()) {
655 output.Set(it->second, kv.second);
656 } else {
657 output.Set(kv.first, kv.second);
658 }
659 }
660 return output;
661 }
662
663 PrimExpr VisitExpr_(const VarNode* op) final {
664 auto it = buf_remap_.find(GetRef<Var>(op));
665 if (it != buf_remap_.end()) {
666 return it->second->data;
667 } else {
668 return GetRef<PrimExpr>(op);
669 }
670 }
671
672 Stmt VisitStmt_(const AttrStmtNode* op) final {
673 ICHECK_NE(op->attr_key, attr::buffer_dim_align)
674 << "StorageFlattener assumes that all buffers have accurate strides, "
675 << "and all buffer_dim_align annotations are removed. "
676 << "Please run BufferStrideLegalize first.";
677
678 if (op->attr_key == attr::thread_extent) {
679 IterVar iv = Downcast<IterVar>(op->node);
680 ThreadScope ts = ThreadScope::Create(iv->thread_tag);
681 curr_thread_scope_.push_back(ts);
682 Stmt stmt = StmtExprMutator::VisitStmt_(op);
683 curr_thread_scope_.pop_back();
684 return stmt;
685 } else if (op->attr_key == attr::buffer_bind_scope) {
686 return HandleBufferBindScope(op);
687 } else {
688 return StmtExprMutator::VisitStmt_(op);
689 }
690 }
691
692 Stmt VisitStmt_(const BufferRealizeNode* op) final {
693 Var old_var = op->buffer->data;
694
695 // Don't remap buffers that already have an explicit scope,
696 // or external buffers.
697 std::string str_scope = GetPtrStorageScope(old_var);
698 if ((str_scope.length() > 0) || external_buffers_.count(op->buffer)) {
699 return StmtExprMutator::VisitStmt_(op);
700 }
701
702 ICHECK_EQ(buf_remap_.count(old_var), 0)
703 << "Buffer var " << op->buffer->data << " appears in multiple BufferRealize nodes";
704
705 StorageScope skey;
706 if (curr_thread_scope_.size() == 0) {
707 skey.rank = StorageRank::kGlobal;
708 } else {
709 skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank);
710 }
711
712 auto ptr_type = old_var->type_annotation.as<PointerTypeNode>();
713 ICHECK(ptr_type);
714 Var new_var(old_var->name_hint, PointerType(ptr_type->element_type, skey.to_string()),
715 old_var->span);
716
717 Buffer buf = op->buffer;
718 buf.CopyOnWrite()->data = new_var;
719
720 buf_remap_[old_var] = buf;
721
722 Stmt body = this->VisitStmt(op->body);
723 return BufferRealize(buf, op->bounds, op->condition, body, op->span);
724 }
725
726 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
727 PrimExpr expr = StmtExprMutator::VisitExpr_(op);
728 op = expr.as<BufferLoadNode>();
729 ICHECK(op);
730
731 auto it = buf_remap_.find(op->buffer->data);
732 if (it != buf_remap_.end()) {
733 return BufferLoad(it->second, op->indices, op->span);
734 } else {
735 return expr;
736 }
737 }
738
739 Stmt VisitStmt_(const BufferStoreNode* op) final {
740 Stmt stmt = StmtExprMutator::VisitStmt_(op);
741 op = stmt.as<BufferStoreNode>();
742 ICHECK(op);
743
744 auto it = buf_remap_.find(op->buffer->data);
745 if (it != buf_remap_.end()) {
746 return BufferStore(it->second, op->value, op->indices, op->span);
747 } else {
748 return stmt;
749 }
750 }
751
752 private:
753 // If the rewritten buffers are part of a buffer_bind_scope, either
754 // as the buffer view or as the buffer being viewed, then the
755 // buffer_bind_scope must be rewritten to refer to the updated
756 // buffers.
757 Stmt HandleBufferBindScope(const AttrStmtNode* op) {
758 Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
759 ICHECK_EQ(arr.size(), 2U);
760 Buffer buffer = Downcast<Buffer>(arr[0]);
761 ICHECK(buffer.defined());
762 Buffer target = Downcast<Buffer>(arr[1]);
763 ICHECK(target.defined());
764
765 bool needs_rewrite = false;
766
767 {
768 auto it = buf_remap_.find(buffer->data);
769 if (it != buf_remap_.end()) {
770 needs_rewrite = true;
771 buffer = it->second;
772 }
773 }
774
775 {
776 auto it = buf_remap_.find(target->data);
777 if (it != buf_remap_.end()) {
778 needs_rewrite = true;
779 target = it->second;
780 }
781 }
782
783 if (needs_rewrite) {
784 Stmt body = this->VisitStmt(op->body);
785 return AttrStmt(Array<ObjectRef>{buffer, target}, op->attr_key, op->value, body);
786 } else {
787 return StmtExprMutator::VisitStmt_(op);
788 }
789 }
790
791 std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buf_remap_;
792 std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> external_buffers_;
793
794 // The current thread scope.
795 std::vector<ThreadScope> curr_thread_scope_;
796};
797
798/* Map buffer binds to their source buffer
799 *
800 * Buffers defined using an attr::buffer_bind_scope annotation are
801 * views into some linked buffer, potentially into some restricted
802 * subregion of that buffer. This pass identifies such buffers, then
803 * rewrites all access of the bound buffers to be access into the
804 * linked buffer.
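 *
 * A rough sketch (not from an actual lowering): if buffer V is bound
 * via buffer_bind_scope to the region of buffer A starting at A[4, 0]
 * with extents [4, 16], then an access V[i, j] in the body is
 * rewritten to A[i + 4, j], and the bind annotation itself is dropped.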
805 */
806class BufferBindUnwrapper : public StmtExprMutator {
807 public:
808 static transform::Pass Pass() {
809 auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) {
810 IRVisitorWithAnalyzer bound_analyzer;
811
812 bound_analyzer(func->body);
813
814 auto pass = BufferBindUnwrapper(func->buffer_map, &bound_analyzer);
815
816 auto fptr = func.CopyOnWrite();
817 fptr->body = pass(std::move(fptr->body));
818 return func;
819 };
820 return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferBindUnwrapper", {});
821 }
822
823 explicit BufferBindUnwrapper(const Map<Var, Buffer>& extern_buffer_map,
824 IRVisitorWithAnalyzer* bound_analyzer)
825 : bound_analyzer_(bound_analyzer) {
826 for (auto kv : extern_buffer_map) {
827 BufferEntry e;
828 e.buffer = kv.second;
829 e.external = true;
830 var_to_buffer_[kv.second->data.get()] = kv.second;
831 buf_map_[kv.second.get()] = std::move(e);
832 }
833 }
834
835 Map<Buffer, Array<IndexMap>> UpdateIndexMap(const Map<Buffer, Array<IndexMap>>& orig) {
836 Map<Buffer, Array<IndexMap>> output;
837 for (const auto& kv : orig) {
838 const BufferEntry& e = GetBufferEntry(kv.first);
839
840 if (e.remap) {
841 output.Set(e.remap->target, kv.second);
842 } else {
843 output.Set(kv.first, kv.second);
844 }
845 }
846 return output;
847 }
848
849 Stmt VisitStmt_(const StoreNode* op) final {
850 LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
851 }
852
853 PrimExpr VisitExpr_(const LoadNode* op) final {
854 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
855 }
856
857 Stmt VisitStmt_(const AttrStmtNode* op) final {
858 ICHECK_NE(op->attr_key, attr::buffer_dim_align)
859 << "BufferBindUnwrapper assumes that all buffers have accurate strides, "
860 << "and all buffer_dim_align annotations are removed. "
861 << "Please run BufferStrideLegalize first.";
862
863 if (op->attr_key == attr::buffer_bind_scope) {
864 return HandleBufferBindScope(op);
865 } else {
866 return StmtExprMutator::VisitStmt_(op);
867 }
868 }
869
870 PrimExpr VisitExpr_(const VarNode* op) final {
871 ICHECK(!illegal_vars_.count(op)) << "Variable " << op->name_hint << " is not well defined. "
872 << "(e.g. use of buffer.elem_offset for a non-flat buffer)";
873
874 auto it = var_remap_.find(op);
875 if (it != var_remap_.end()) {
876 return it->second;
877 } else {
878 return GetRef<PrimExpr>(op);
879 }
880 }
881
882 Array<PrimExpr> remap_indices(Array<PrimExpr> indices, Array<PrimExpr> begins,
883 Array<PrimExpr> extents) {
884 ICHECK_EQ(begins.size(), extents.size());
885
886 if (begins.size() == 0) {
887 return indices;
888 }
889
890 ICHECK_EQ(begins.size(), indices.size());
891
892 Array<PrimExpr> out;
893 for (size_t i = 0; i < begins.size(); i++) {
894 out.push_back(begins[i] + indices[i]);
895 }
896 return out;
897 }
898
899 Array<Range> remap_bounds(Array<Range> bounds, Array<PrimExpr> begins, Array<PrimExpr> extents) {
900 ICHECK_EQ(begins.size(), extents.size());
901
902 if (begins.size() == 0) {
903 return bounds;
904 }
905
906 ICHECK_EQ(begins.size(), bounds.size());
907
908 Array<Range> out;
909 for (size_t i = 0; i < begins.size(); i++) {
910 out.push_back(Range::FromMinExtent(bounds[i]->min + begins[i], bounds[i]->extent));
911 }
912 return out;
913 }
914
915 // AllocateNodes may be present from tvm.tir.ir_builder. This can
916 // be simplified in the future by having AllocateNode hold a buffer,
917 // rather than a buffer_var.
918 Stmt VisitStmt_(const AllocateNode* op) final {
919 buffer_var_defines_.insert(op->buffer_var.get());
920 return StmtExprMutator::VisitStmt_(op);
921 }
922
923 Stmt VisitStmt_(const AllocateConstNode* op) final {
924 buffer_var_defines_.insert(op->buffer_var.get());
925 return StmtExprMutator::VisitStmt_(op);
926 }
927
928 Stmt VisitStmt_(const LetStmtNode* op) final {
929 if (op->var.dtype().is_handle()) {
930 buffer_var_defines_.insert(op->var.get());
931 }
932 return StmtExprMutator::VisitStmt_(op);
933 }
934
935 PrimExpr VisitExpr_(const LetNode* op) final {
936 if (op->var.dtype().is_handle()) {
937 buffer_var_defines_.insert(op->var.get());
938 }
939 return StmtExprMutator::VisitExpr_(op);
940 }
941
942 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
943 PrimExpr expr = StmtExprMutator::VisitExpr_(op);
944 op = expr.as<BufferLoadNode>();
945
946 const BufferEntry& e = GetBufferEntry(op->buffer);
947
948 if (e.remap) {
949 return BufferLoad(e.remap->target,
950 remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
951 } else {
952 return expr;
953 }
954 }
955
956 Stmt VisitStmt_(const BufferStoreNode* op) final {
957 Stmt stmt = StmtExprMutator::VisitStmt_(op);
958 op = stmt.as<BufferStoreNode>();
959
960 const BufferEntry& e = GetBufferEntry(op->buffer);
961
962 if (e.remap) {
963 return BufferStore(e.remap->target, op->value,
964 remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span);
965 } else {
966 return stmt;
967 }
968 }
969
970 Stmt VisitStmt_(const BufferRealizeNode* op) final {
971 const auto& key = op->buffer.get();
972
973 bool is_external = false;
974
975 if (buf_map_.count(key)) {
976 ICHECK(buf_map_.at(key).external)
977 << "BufferRealize node for internal buffer " << op->buffer << " occurred multiple times.";
978
979 is_external = true;
980 } else {
981 BufferEntry e;
982 e.bounds = op->bounds;
983 e.buffer = op->buffer;
984 var_to_buffer_[op->buffer->data.get()] = op->buffer;
985 buf_map_[key] = std::move(e);
986 }
987
988 Stmt stmt = StmtExprMutator::VisitStmt_(op);
989
990 if (is_external) {
991 buf_map_[key].in_scope = false;
992 }
993
994 return stmt;
995 }
996
997 Stmt VisitStmt_(const PrefetchNode* op) final {
998 Stmt stmt = StmtExprMutator::VisitStmt_(op);
999 op = stmt.as<PrefetchNode>();
1000 ICHECK(op != nullptr);
1001
1002 const BufferEntry& e = GetBufferEntry(op->buffer);
1003
1004 ICHECK(e.in_scope) << "Read a buffer that is already out of scope";
1005 ICHECK_EQ(e.buffer->shape.size(), op->bounds.size())
1006 << "Prefetch dim should be the same as buffer dim";
1007
1008 if (e.remap) {
1009 return Prefetch(e.remap->target, remap_bounds(op->bounds, e.remap->begins, e.remap->extents),
1010 op->span);
1011 } else {
1012 return stmt;
1013 }
1014 }
1015
1016 private:
1017 // Read the mapping from a buffer view to the actual buffer. This
1018 // allows all later BufferStore/BufferLoad nodes to reference the
1019 // actual buffer, rather than the buffer view.
1020 Stmt HandleBufferBindScope(const AttrStmtNode* op) {
1021 // Unpack information from Attribute node
1022 RemapInfo remap;
1023
1024 Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
1025 ICHECK_EQ(arr.size(), 2U);
1026 const Buffer source = Downcast<Buffer>(arr[0]);
1027 ICHECK(source.defined());
1028 remap.target = Downcast<Buffer>(arr[1]);
1029 ICHECK(remap.target.defined());
1030 const CallNode* tuple = op->value.as<CallNode>();
1031 ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
1032
1033 for (size_t i = 0; i < tuple->args.size(); i += 2) {
1034 remap.begins.push_back(tuple->args[i]);
1035 remap.extents.push_back(tuple->args[i + 1]);
1036 }
1037
1038 // Determine bounds in the target buffer
1039 auto it = buf_map_.find(remap.target.get());
1040 ICHECK(it != buf_map_.end()) << "Cannot define " << source << " as a view into " << remap.target
1041 << ", " << remap.target << " was not defined.";
1042 const BufferEntry& target_info = it->second;
1043 ICHECK(target_info.in_scope) << "Cannot define " << source << " as a view into " << remap.target
1044 << ", " << remap.target << " is out of scope.";
1045 ICHECK_EQ(remap.begins.size(), target_info.buffer->shape.size())
1046 << "Incorrect number of arguments in buffer_bind_scope attribute. "
        << "Expected (min_0, extent_0, min_1, extent_1, ..., min_N, extent_N).";
1048
1049 if (target_info.bounds.size() > 0) {
1050 Array<PrimExpr> mapped_begins;
1051 for (size_t i = 0; i < target_info.buffer->shape.size(); ++i) {
1052 mapped_begins.push_back(remap.begins[i] - target_info.bounds[i]->min);
1053 }
1054 remap.begins = std::move(mapped_begins);
1055 }
1056
1057 ICHECK(target_info.remap == nullptr)
1058 << "buffer_bind_scope defines " << source << " as a view into " << remap.target
1059 << ", which is itself a buffer view. "
1060 << "Indirect remapping not currently supported.";
1061
1062 for (size_t i = 0; i < remap.begins.size(); i++) {
1063 remap.begins.Set(i, bound_analyzer_->Simplify(remap.begins[i]));
1064 remap.extents.Set(i, bound_analyzer_->Simplify(remap.extents[i]));
1065 }
1066
1067 // Add a buffer remap entry
1068 {
1069 BufferEntry source_info;
1070 source_info.buffer = source;
1071 source_info.remap = std::make_unique<RemapInfo>(remap);
1072
1073 var_to_buffer_[source->data.get()] = source;
1074 buf_map_[source.get()] = std::move(source_info);
1075 }
1076
1077 // Define remappings of any remaining Variables (e.g. Store/Load nodes).
1078 ArgBinder binder(&var_remap_);
1079
1080 // Define a view that represents the source's view into the target
1081 // buffer. This Buffer object is only used to define the mapping
1082 // to the target buffer, and never actually appears in the TIR
1083 // graph.
1084 Buffer view = remap.target.MakeSlice(remap.begins, remap.extents);
1085 if (source->strides.size() == 0) {
1086 ICHECK_EQ(view->strides.size(), 0U)
1087 << "Cannot bind a compact buffer " << source << " to a strided buffer " << view
1088 << " with strides " << view->strides;
1089 } else {
1090 // Add explicit strides to the view, in order to bind to source.strides[i].
1091 view = view.MakeStrideView();
1092 }
1093
1094 // Match integer bits of source->elem_offset and view->elem_offset
1095 // as is required by ArgBinder::Bind_
1096 if (view->elem_offset.defined() && source->elem_offset.dtype() != view->elem_offset.dtype()) {
1097 view.CopyOnWrite()->elem_offset = cast(source->elem_offset.dtype(), view->elem_offset);
1098 }
1099
1100 // Bind any variables that reference the view (e.g. elem_offset,
1101 // strides, shape). Pass fuzzy_match=false, because all shape
1102 // transformations should have been handled in
1103 // BufferShapeLegalize.
1104 binder.BindBuffer(source, view, source->name, false);
1105 if (auto* elem_offset_var = source->elem_offset.as<VarNode>()) {
1106 if (!view->elem_offset.defined()) {
1107 illegal_vars_.insert(elem_offset_var);
1108 }
1109 }
1110
1111 // Apply the remaps
1112 Stmt body = op->body;
1113 body = MergeNest(binder.asserts(), body);
1114 body = MergeNest(binder.init_nest(), body);
1115 body = this->VisitStmt(body);
1116 // remove the binds
1117 for (const Var& v : binder.defs()) {
1118 var_remap_.erase(v.get());
1119 }
1120 return body;
1121 }
1122
1123 struct RemapInfo {
1124 Buffer target;
1125 Array<PrimExpr> begins;
1126 Array<PrimExpr> extents;
1127 };
1128
1129 // The buffer entry in the flatten map
1130 struct BufferEntry {
1131 // The storage buffer
1132 Buffer buffer;
1133 // the bounds of realization, can be null, means everything
1134 Region bounds;
1135 // Whether the buffer is external
1136 bool external{false};
1137 // Whether we are within the allocation scope of the buffer.
1138 bool in_scope{true};
1139
1140 // The buffer to which the storage buffer should be remapped.
1141 std::unique_ptr<RemapInfo> remap{nullptr};
1142 };
1143
1144 const BufferEntry& GetBufferEntry(Buffer buffer) {
1145 if (buf_map_.count(buffer.get())) {
1146 const BufferEntry& e = buf_map_[buffer.get()];
1147 ICHECK(e.in_scope) << "Cannot access a buffer " << buffer->name << ", out of scope";
1148 return e;
1149 } else if (buffer_var_defines_.count(buffer->data.get())) {
1150 // The buffer var was defined, but the buffer hasn't been seen
1151 // before.
1152 BufferEntry entry;
1153 entry.buffer = buffer;
1154 var_to_buffer_[buffer->data.get()] = buffer;
1155 buf_map_[buffer.get()] = std::move(entry);
1156 return buf_map_[buffer.get()];
1157 } else if (var_remap_.count(buffer->data.get())) {
1158 // The buffer var is an alias of a bound buffer. Only
1159 // supported if the bound buffer has no offsets. In this
1160 // case, we just need to make a new aliasing buffer that
1161 // shares the remapped data variable.
1162 Var old_var = buffer->data;
1163 Var new_var = Downcast<Var>(var_remap_[old_var.get()]);
1164
1165 {
1166 ICHECK(var_to_buffer_.count(old_var.get()))
1167 << "Cannot find remap information for aliased buffer var " << old_var->name_hint
1168 << ", required to verify this alias is legal.";
1169 const Buffer& aliased_buffer = var_to_buffer_[old_var.get()];
1170 const BufferEntry& entry = buf_map_[aliased_buffer.get()];
1171 if (entry.remap) {
1172 for (const auto& begin : entry.remap->begins) {
1173 ICHECK(is_zero(begin)) << "Aliasing of buffer with offset is not supported";
1174 }
1175 }
1176 }
1177
1178 {
1179 Buffer new_buf = buffer;
1180 new_buf.CopyOnWrite()->data = new_var;
1181
1182 RemapInfo remap_info;
1183 remap_info.target = new_buf;
1184 remap_info.begins = Array<PrimExpr>(buffer->shape.size(), 0);
1185 remap_info.extents = buffer->shape;
1186
1187 BufferEntry entry;
1188 entry.buffer = buffer;
1189 entry.remap = std::make_unique<RemapInfo>(remap_info);
1190 entry.in_scope = true;
1191 var_to_buffer_[buffer->data.get()] = buffer;
1192 buf_map_[buffer.get()] = std::move(entry);
1193 }
1194 return buf_map_[buffer.get()];
1195 } else if (var_to_buffer_.count(buffer->data.get())) {
1196 // This buffer is an alias of a known buffer, with no remaps. A
1197 // buffer entry should be generated and returned.
1198 BufferEntry entry;
1199 entry.buffer = buffer;
1200 entry.in_scope = true;
1201 var_to_buffer_[buffer->data.get()] = buffer;
1202 buf_map_[buffer.get()] = std::move(entry);
1203
1204 return buf_map_[buffer.get()];
1205 } else {
      LOG(FATAL) << "Cannot find a definition or allocation for buffer " << buffer->name;
1207 }
1208 }
1209
1210 // The buffer assignment map
1211 // Variable remap
1212 std::unordered_map<const VarNode*, PrimExpr> var_remap_;
1213 // Variables that may not occur within the body.
1214 std::unordered_set<const VarNode*> illegal_vars_;
1215 // Buffer map
1216 std::unordered_map<const BufferNode*, BufferEntry> buf_map_;
1217 // Map from Var to the Buffer they occurred in. In case of aliased
1218 // buffers, contains the first buffer.
1219 std::unordered_map<const VarNode*, Buffer> var_to_buffer_;
1220 // Set of vars that have occurred in an AllocateNode, but haven't
1221 // yet occurred in a BufferLoad/BufferStore.
1222 std::unordered_set<const VarNode*> buffer_var_defines_;
  // Analyzer for the variable bounds, used to simplify the bounds of
  // buffer realizations and the extents of buffer views.
1225 IRVisitorWithAnalyzer* bound_analyzer_;
1226};
1227
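/* Apply any layout transforms attached to the PrimFunc
 *
 * Buffers listed in the "layout_transform_map" attribute have each of
 * their IndexMaps applied to the buffer shape, realization bounds, and
 * access indices, after which the attribute is removed.
 *
 * A rough sketch (illustrative only): with an IndexMap
 * (i, j) -> (i floordiv 4, j, i mod 4) attached to buffer A of shape
 * [16, 64], A is rewritten to shape [4, 64, 4] and an access A[i, j]
 * becomes A[i floordiv 4, j, i mod 4].
 */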
1228class ApplyLayoutTransforms : public StmtExprMutator {
1229 public:
1230 static transform::Pass Pass() {
1231 auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) {
1232 auto lookup = func->attrs.GetAttr<Map<Buffer, Array<IndexMap>>>("layout_transform_map");
1233
1234 if (!lookup) {
1235 return func;
1236 }
1237
1238 Map<Buffer, Array<IndexMap>> layout_transforms = lookup.value();
1239
1240 auto fptr = func.CopyOnWrite();
1241
1242 auto mutator = ApplyLayoutTransforms(layout_transforms);
1243 fptr->buffer_map = mutator.UpdateExternBufferMap(fptr->buffer_map);
1244 fptr->body = mutator(std::move(fptr->body));
1245
1246 return WithoutAttr(std::move(func), "layout_transform_map");
1247 };
1248 return transform::CreatePrimFuncPass(pass_func, 0, "tir.ApplyLayoutTransforms", {});
1249 }
1250
1251 explicit ApplyLayoutTransforms(Map<Buffer, Array<IndexMap>> layout_transforms)
1252 : layout_transforms_(layout_transforms) {}
1253
1254 Map<tir::Var, Buffer> UpdateExternBufferMap(const Map<tir::Var, Buffer>& buffer_map) {
1255 Map<tir::Var, Buffer> output;
1256 for (const auto& kv : buffer_map) {
1257 output.Set(kv.first, GetBufferRemap(kv.second, true));
1258 }
1259 return output;
1260 }
1261
1262 Stmt VisitStmt_(const BufferRealizeNode* op) final {
1263 // Call once so that load/store nodes can read from the cached
1264 // value.
1265 GetBufferRemap(op->buffer, true);
1266
1267 auto realize = Downcast<BufferRealize>(StmtExprMutator::VisitStmt_(op));
1268
1269 auto lookup = layout_transforms_.Get(op->buffer);
1270 if (lookup) {
1271 auto write_ptr = realize.CopyOnWrite();
1272 write_ptr->buffer = GetBufferRemap(op->buffer, true);
1273
1274 Array<IndexMap> transforms = lookup.value();
1275 for (const auto& transform : transforms) {
1276 write_ptr->bounds = transform->MapRanges(realize->bounds);
1277 }
1278 }
1279
1280 return std::move(realize);
1281 }
1282
1283 Stmt VisitStmt_(const BufferStoreNode* op) final {
1284 auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
1285 return VisitBufferAccess(std::move(node));
1286 }
1287
1288 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
1289 auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
1290 return VisitBufferAccess(std::move(node));
1291 }
1292
1293 template <typename Node>
1294 Node VisitBufferAccess(Node node) {
1295 auto lookup = layout_transforms_.Get(node->buffer);
1296 if (lookup) {
1297 auto write_ptr = node.CopyOnWrite();
1298
1299 write_ptr->buffer = GetBufferRemap(node->buffer);
1300
1301 Array<IndexMap> transforms = lookup.value();
1302 for (const auto& transform : transforms) {
1303 write_ptr->indices = transform->MapIndices(node->indices);
1304 }
1305 }
1306 return node;
1307 }
1308
1309 private:
1310 //! \brief Given a buffer, return the buffer it should be remapped into.
1311 Buffer GetBufferRemap(Buffer buf, bool allow_alloc = false) {
1312 auto key = buf.get();
1313 auto it = buf_map_.find(key);
1314 if (it != buf_map_.end()) {
1315 return it->second;
1316 }
1317
1318 ICHECK(allow_alloc) << "Buffer " << buf << " accessed before declaration.";
1319
1320 auto lookup = layout_transforms_.Get(buf);
1321 if (lookup) {
1322 Array<IndexMap> transforms = lookup.value();
1323
1324 auto write_ptr = buf.CopyOnWrite();
1325 for (const auto& transform : transforms) {
1326 write_ptr->shape = transform->MapShape(buf->shape);
1327 }
1328 }
1329
1330 buf_map_[key] = buf;
1331 return buf;
1332 }
1333
1334 std::unordered_map<const BufferNode*, Buffer> buf_map_;
1335
1336 Map<Buffer, Array<IndexMap>> layout_transforms_;
1337};
1338
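/* Flatten multi-dimensional buffer access into flat memory access
 *
 * BufferRealize nodes are replaced with Allocate nodes over the
 * flattened buffer, and each BufferLoad/BufferStore is rewritten to
 * index into the flattened buffer using Buffer::ElemOffset.
 *
 * A rough sketch (illustrative only): a realization of buffer A with
 * shape [16, 32] becomes an allocation of 512 elements, and an access
 * A[i, j] becomes an access at the flattened offset i * 32 + j (or a
 * strided equivalent if A carries explicit strides).
 */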
1339class StorageFlattener : public StmtExprMutator {
1340 public:
1341 static transform::Pass Pass(int cache_line_size, bool create_bound_attributes) {
1342 auto pass_func = [=](PrimFunc func, IRModule m, transform::PassContext ctx) {
1343 IRVisitorWithAnalyzer bound_analyzer;
1344
1345 bound_analyzer(func->body);
1346
1347 auto pass = StorageFlattener(func->buffer_map, cache_line_size, create_bound_attributes,
1348 &bound_analyzer);
1349
1350 auto fptr = func.CopyOnWrite();
1351 fptr->body = pass(std::move(fptr->body));
1352 // The buffers in func->buffer_map are deliberately left
1353 // unflattened, as they are used for validation of user-provided
1354 // arguments. The flattened buffers used in the updated
1355 // function body alias the argument buffers.
1356 return func;
1357 };
1358 return transform::CreatePrimFuncPass(pass_func, 0, "tir.StorageFlattener", {});
1359 }
1360
1361 explicit StorageFlattener(const Map<Var, Buffer>& extern_buffer_map, int cache_line_size,
1362 bool create_bound_attributes, IRVisitorWithAnalyzer* bound_analyzer)
1363 : bound_analyzer_(bound_analyzer), create_bound_attributes_(create_bound_attributes) {
1364 for (auto kv : extern_buffer_map) {
1365 BufferEntry e;
1366 e.buffer = kv.second;
1367 e.flattened_buffer = e.buffer.GetFlattenedBuffer();
1368 // TODO(Lunderberg): Move the handling of boolean into a
1369 // dedicated pass.
1370
1371 // Boolean tensors are backed by a Int8 array.
1372 if (e.buffer->dtype == DataType::Bool()) {
1373 {
1374 auto writer = e.buffer.CopyOnWrite();
1375 writer->dtype = DataType::Int(8);
1376 }
1377 {
1378 auto writer = e.flattened_buffer.CopyOnWrite();
1379 writer->dtype = DataType::Int(8);
1380 }
1381 }
1382 e.external = true;
1383 buffer_var_defines_.insert(kv.second->data.get());
1384 buf_map_[kv.second] = e;
1385 }
1386 cache_line_size_ = cache_line_size;
1387 }
1388
1389 Stmt VisitStmt_(const StoreNode* op) final {
1390 LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
1391 }
1392
1393 PrimExpr VisitExpr_(const LoadNode* op) final {
1394 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
1395 }
1396
1397 Stmt VisitStmt_(const AttrStmtNode* op) final {
1398 ICHECK_NE(op->attr_key, attr::buffer_dim_align)
1399 << "StorageFlattener assumes that all buffers have accurate strides, "
1400 << "and all buffer_dim_align annotations are removed. "
1401 << "Please run BufferStrideLegalize first.";
1402
1403 ICHECK_NE(op->attr_key, attr::buffer_bind_scope)
1404 << "StorageFlattener assumes that all buffer binds have already been applied. "
1405 << "Please run BufferBindUnwrapper first.";
1406
1407 if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance<tir::BufferNode>()) {
1408 auto buffer = Downcast<tir::Buffer>(op->node);
1409 Stmt body = this->VisitStmt(op->body);
1410 const auto& entry = GetBufferEntry(buffer);
1411 body = AttrStmt(entry.flattened_buffer->data, op->attr_key, op->value, std::move(body));
1412 return body;
1413 }
1414 return StmtExprMutator::VisitStmt_(op);
1415 }
1416
1417 Stmt VisitStmt_(const BufferStoreNode* op) final {
1418 if (create_bound_attributes_) shape_collector_.clear();
1419 Stmt stmt = StmtExprMutator::VisitStmt_(op);
1420 op = stmt.as<BufferStoreNode>();
1421
1422 const BufferEntry& e = GetBufferEntry(op->buffer);
1423
1424 // Handle casts from the value's dtype to the dtype of the backing
1425 // array.
1426 PrimExpr value = op->value;
1427 if (value.dtype() == DataType::Bool()) {
1428 ICHECK_EQ(e.flattened_buffer->dtype, DataType::Int(8))
1429 << "Expected int8 backing array for boolean tensor, but received "
1430 << e.flattened_buffer->dtype;
1431 value = tir::Cast(DataType::Int(8), value);
1432 }
1433
1434 auto flattened_indices = e.buffer->ElemOffset(op->indices);
1435
1436 Stmt body = BufferStore(e.flattened_buffer, value, flattened_indices, op->span);
1437 if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
1438 shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape));
1439 }
    // To create bound attributes, the collector must have at least one item.
1441 if (create_bound_attributes_ && shape_collector_.size()) {
1442 for (size_t i = 0; i < shape_collector_.size(); ++i) {
1443 body = AttrStmt(shape_collector_[i].first, tir::attr::buffer_bound,
1444 MakeBound(e.buffer->dtype, shape_collector_[i].second), body);
1445 }
1446 }
1447 return body;
1448 }
1449
1450 // AllocateNodes may be present from tvm.tir.ir_builder. This can
1451 // be simplified in the future by having AllocateNode hold a buffer,
1452 // rather than a buffer_var.
1453 Stmt VisitStmt_(const AllocateNode* op) final {
1454 buffer_var_defines_.insert(op->buffer_var.get());
1455 auto stmt = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
1456 return Allocate(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), stmt->condition,
1457 stmt->body, stmt->annotations, stmt->span);
1458 }
1459
1460 Stmt VisitStmt_(const AllocateConstNode* op) final {
1461 buffer_var_defines_.insert(op->buffer_var.get());
1462 auto stmt = Downcast<AllocateConst>(StmtExprMutator::VisitStmt_(op));
1463 ObjectRef data_or_idx;
1464 if (stmt->data) {
1465 data_or_idx = stmt->data.value();
1466 } else if (stmt->irmod_storage_idx) {
1467 data_or_idx = stmt->irmod_storage_idx.value();
1468 } else {
1469 LOG(FATAL) << "Neither data array nor data index specified for allocation of const "
1470 << op->buffer_var->name_hint;
1471 }
1472 return AllocateConst(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), data_or_idx,
1473 stmt->body, stmt->annotations, stmt->span);
1474 }
1475
1476 Stmt VisitStmt_(const LetStmtNode* op) final {
1477 if (op->var.dtype().is_handle()) {
1478 buffer_var_defines_.insert(op->var.get());
1479 }
1480 return StmtExprMutator::VisitStmt_(op);
1481 }
1482
1483 PrimExpr VisitExpr_(const LetNode* op) final {
1484 if (op->var.dtype().is_handle()) {
1485 buffer_var_defines_.insert(op->var.get());
1486 }
1487 return StmtExprMutator::VisitExpr_(op);
1488 }
1489
1490 Stmt VisitStmt_(const BufferRealizeNode* op) final {
1491 const auto& key = op->buffer;
1492
1493 if (buf_map_.count(key)) {
1494 ICHECK(buf_map_.at(key).external)
1495 << "BufferRealize for internal buffer " << op->buffer << " appears multiple times.";
1496 return this->VisitStmt(op->body);
1497 } else {
1498 // create a buffer entry
1499 BufferEntry e;
1500
1501 ICHECK_EQ(op->buffer->shape.size(), op->bounds.size())
1502 << "Inconsistent buffer shape and realization shape for " << op->buffer;
1503
1504 for (size_t i = 0; i < op->bounds.size(); i++) {
1505 const auto& bound = op->bounds[i];
1506 const auto& dim_size = op->buffer->shape[i];
1507 ICHECK(is_zero(bound_analyzer_->Simplify(bound->min)))
1508 << "Buffer " << op->buffer << " has realization bounds that do not start at zero. "
1509 << "Please run BufferShapeLegalize first.";
1510 ICHECK(is_one(bound_analyzer_->Simplify(bound->extent == dim_size)))
1511 << "Buffer " << op->buffer
1512 << " has realization extent that does not match its size. "
1513 "Please run BufferShapeLegalize first.";
1514 }
1515
1516 StorageScope skey = StorageScope::Create(GetPtrStorageScope(op->buffer->data));
1517
1518 // use small alignment for small arrays
1519 auto dtype = op->buffer->dtype;
1520 size_t const_size = AllocateNode::ConstantAllocationSize(op->buffer->shape);
1521 int align = GetTempAllocaAlignment(dtype, const_size);
1522 if (skey.tag.length() != 0) {
1523 MemoryInfo info = GetMemoryInfo(skey.to_string());
1524 if (info.defined()) {
1525 align = (info->max_simd_bits + dtype.bits() - 1) / dtype.bits();
1526 ICHECK_LE(const_size * dtype.bits(), info->max_num_bits)
1527 << "Allocation exceed bound of memory tag " << skey.to_string();
1528 }
1529 }
1530
1531 e.buffer = Buffer(op->buffer->data, op->buffer->dtype, op->buffer->shape, op->buffer->strides,
1532 PrimExpr(), op->buffer->name, align, 0, kDefault,
1533 op->buffer->axis_separators, op->buffer->span);
1534 e.flattened_buffer = e.buffer.GetFlattenedBuffer();
1535
1536 // TODO(Lunderberg): Move the handling of boolean into a
1537 // dedicated pass.
1538
1539 // Boolean tensors are backed by a Int8 array.
1540 if (e.flattened_buffer->dtype == DataType::Bool()) {
1541 auto writer = e.flattened_buffer.CopyOnWrite();
1542 writer->dtype = DataType::Int(8);
1543 }
1544
1545 buffer_var_defines_.insert(op->buffer->data.get());
1546 buf_map_[key] = e;
1547 Stmt body = this->VisitStmt(op->body);
1548 buffer_var_defines_.erase(op->buffer->data.get());
1549 buf_map_[key].in_scope = false;
1550
1551 Stmt ret =
1552 Allocate(e.flattened_buffer->data, e.flattened_buffer->dtype, e.flattened_buffer->shape,
1553 make_const(DataType::Bool(e.flattened_buffer->dtype.lanes()), true), body);
1554
1555 if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
1556 ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound,
1557 MakeBound(e.buffer->dtype, e.buffer->shape), ret);
1558 }
1559 return ret;
1560 }
1561 }
1562
1563 PrimExpr VisitExpr_(const VarNode* op) final {
1564 auto it = var_remap_.find(op);
1565 if (it != var_remap_.end()) {
1566 return it->second;
1567 } else {
1568 return GetRef<PrimExpr>(op);
1569 }
1570 }
1571
1572 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
1573 PrimExpr expr = StmtExprMutator::VisitExpr_(op);
1574 op = expr.as<BufferLoadNode>();
1575
1576 const BufferEntry& e = GetBufferEntry(op->buffer);
1577
1578 if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
1579 shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape));
1580 }
1581
1582 auto flattened_indices = e.buffer->ElemOffset(op->indices);
1583 PrimExpr val = BufferLoad(e.flattened_buffer, flattened_indices, op->span);
1584
1585 if (op->dtype == DataType::Bool()) {
1586 ICHECK_EQ(e.flattened_buffer->dtype, DataType::Int(8))
1587 << "Expected int8 backing array for boolean tensor, but received "
1588 << e.flattened_buffer->dtype;
1589 val = tir::Cast(DataType::Bool(), val);
1590 }
1591
1592 return val;
1593 }
1594
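
  // Lower a Prefetch node into a loop nest of builtin::prefetch calls.
  // Innermost dimensions whose combined constant size fits within one
  // cache line are fetched together as a single block; the remaining
  // dimensions become serial loops, with the loop at the block boundary
  // advancing by roughly one cache line per iteration.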
  Stmt VisitStmt_(const PrefetchNode* op) final {
    const BufferEntry& e = GetBufferEntry(op->buffer);

    ICHECK(e.in_scope) << "Cannot prefetch " << op->buffer << ", out of scope.";
    ICHECK_EQ(e.buffer->shape.size(), op->bounds.size())
        << "Prefetch dim should be the same as buffer dim";

    int block_size = 1, elem_cnt = cache_line_size_ / e.buffer->dtype.bytes();

    int starts = op->bounds.size() - 1;

    while (starts > 0) {
      auto* shape_as_int = e.buffer->shape[starts].as<IntImmNode>();
      if (shape_as_int == nullptr || block_size * shape_as_int->value > elem_cnt) break;
      block_size *= static_cast<int>(shape_as_int->value);
      starts--;
    }
    PrimExpr stride(elem_cnt / block_size);

    Array<PrimExpr> args;
    std::vector<Var> vars;

    for (int i = op->bounds.size() - 1; i > starts; --i) {
      args.push_back(op->bounds[i]->min);
    }
    auto& func_name = op->buffer->name;
    vars.push_back(Var("prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32)));
    args.push_back(op->bounds[starts]->min + stride * vars.back());
    for (int i = starts - 1; i >= 0; --i) {
      vars.push_back(Var("prefetch." + func_name + "." + std::to_string(i), DataType::Int(32)));
      args.push_back(vars.back() + op->bounds[i]->min);
    }

    Stmt stmt = GetRef<Stmt>(op);
    for (int i = starts; i >= 0; --i) {
      if (i < starts) {
        stmt = For(vars[i], 0, op->bounds[i]->extent, ForKind::kSerial, stmt);
      } else {
        PrimExpr load = e.buffer.vload(args, e.buffer->dtype);
        PrimExpr address = Call(DataType::Handle(), builtin::address_of(), {load});
        PrimExpr prefetch = Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1});
        stmt = Evaluate(prefetch);
        PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1;
        stmt = For(vars[i], 0, extent, ForKind::kSerial, stmt);
      }
    }
    return this->VisitStmt(stmt);
  }

  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
    LOG(FATAL) << "ProducerLoad cannot appear in a valid TIR PrimFunc. "
               << "Please run SchedulePostProcToPrimFunc first.";
    return PrimExpr();
  }

  Stmt VisitStmt_(const ProducerStoreNode* op) final {
    LOG(FATAL) << "ProducerStore cannot appear in a valid TIR PrimFunc. "
               << "Please run SchedulePostProcToPrimFunc first.";
    return Stmt();
  }

  Stmt VisitStmt_(const ProducerRealizeNode* op) final {
    LOG(FATAL) << "ProducerRealize cannot appear in a valid TIR PrimFunc. "
               << "Please run SchedulePostProcToPrimFunc first.";
    return Stmt();
  }

 private:
  // Helper function for visiting Allocate and AllocateConst. If, in
  // the future, these are updated to hold a buffer (Buffer) object
  // rather than a buffer_var (Var), this function can be replaced
  // with a call to GetBufferEntry.
  template <typename Node>
  Array<PrimExpr> FlattenExtents(const Node& node) {
    arith::Analyzer analyzer;

    // Returns true if the buffer's shape matches the allocation's extents.
    auto is_compatible_buffer = [&](const Buffer& buffer) {
      if (buffer->shape.size() != node->extents.size()) {
        return false;
      }
      for (size_t i = 0; i < buffer->shape.size(); i++) {
        if (!analyzer.CanProveEqual(buffer->shape[i], node->extents[i])) {
          return false;
        }
      }

      return true;
    };

    auto int_array_equal = [](const Array<IntImm>& a, const Array<IntImm>& b) {
      if (a.size() != b.size()) {
        return false;
      }

      for (size_t i = 0; i < a.size(); i++) {
        if (a[i]->value != b[i]->value) {
          return false;
        }
      }

      return true;
    };

    Array<IntImm> axis_separators;
    auto it = buffer_var_map_.find(node->buffer_var.get());
    if (it != buffer_var_map_.end()) {
      const auto& buffers = it->second;
      if (buffers.size() == 0) {
        // No buffers use this allocation, treat as flat and optimize
        // out later.
      } else if (buffers.size() == 1) {
        // Only one buffer uses this allocation, so use its axis
        // separators.
        axis_separators = buffers[0]->axis_separators;
      } else {
        // Try to find a buffer using this allocation with a matching
        // shape.
        Buffer compatible_buffer;
        for (const auto& buffer : buffers) {
          if (is_compatible_buffer(buffer)) {
            ICHECK(!compatible_buffer.defined() ||
                   int_array_equal(compatible_buffer->axis_separators, buffer->axis_separators))
                << "Cannot determine axis separators to use when flattening "
                << node->buffer_var->name_hint
                << ", multiple buffer objects found with conflicting axis separators";
            compatible_buffer = buffer;
          }
        }
        ICHECK(compatible_buffer.defined())
            << "Cannot determine axis separators to use when flattening "
            << node->buffer_var->name_hint << ", no buffers found with matching shape";
        axis_separators = compatible_buffer->axis_separators;
      }
    }

    // Use GetFlattenedBuffer to determine the flattened shape of the
    // output. We only need the shape and axis separators defined;
    // everything else can be dummy values.
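    // As a sketch, extents [2, 3, 4] with axis_separators [1] would
    // flatten to the two-dimensional shape [2, 12].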
    Buffer dummy_buffer =
        decl_buffer(node->extents, DataType::Float(32), "buffer", "", axis_separators);
    return dummy_buffer.GetFlattenedBuffer()->shape;
  }

  // Alignment information for one dimension of a buffer.
  struct DimAlignInfo {
    int align_factor{0};
    int align_offset{0};
  };
  // The buffer entry in the flatten map
  struct BufferEntry {
    // The buffer object
    Buffer buffer;
    // The updated buffer object, after flattening has been applied.
    Buffer flattened_buffer;
    // Whether the buffer is external
    bool external{false};
    // Whether the buffer is currently in scope.
    bool in_scope{true};
  };

  bool ShapeIsValid(const Array<PrimExpr>& shape) {
    // A zero-dimensional tensor does not need a boundary check.
    if (!shape.size()) return false;

    for (size_t i = 0; i < shape.size(); ++i) {
      if (!shape[i].defined() || !shape[i].dtype().is_scalar() || is_negative_const(shape[i])) {
        return false;
      }
    }
    return true;
  }

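  // Build the tvm_tuple bound expression attached by the buffer_bound
  // attribute. As a sketch, a scalar dtype with shape [n, m] yields
  // tvm_tuple((1 * n) * (1 * m)).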
  PrimExpr MakeBound(const DataType& type, const Array<PrimExpr>& shape) {
    // We have already checked the shape size to be greater than 0.
    PrimExpr bound = Mul(make_const(shape[0].dtype(), type.lanes()), shape[0]);
    for (size_t i = 1; i < shape.size(); ++i) {
      bound = Mul(bound, Mul(make_const(bound.dtype(), type.lanes()), shape[i]));
    }
    Array<PrimExpr> bounds{bound};

    return Call(DataType::Handle(), builtin::tvm_tuple(), bounds);
  }

  const BufferEntry& GetBufferEntry(Buffer buffer) {
    auto alloc_key = buffer->data.get();
    if (!buf_map_.count(buffer) && buffer_var_defines_.count(alloc_key)) {
      BufferEntry entry;
      entry.buffer = buffer;
      entry.flattened_buffer = buffer.GetFlattenedBuffer();
      // Boolean tensors are backed by an Int8 array.
      if (entry.flattened_buffer->dtype == DataType::Bool()) {
        auto writer = entry.flattened_buffer.CopyOnWrite();
        writer->dtype = DataType::Int(8);
      }
      buf_map_[buffer] = std::move(entry);
    }

    auto it = buf_map_.find(buffer);
    ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer;
    const BufferEntry& e = it->second;
    ICHECK(e.in_scope) << "Cannot access buffer " << buffer->name << ", out of scope";
    return it->second;
  }

  // Variable remap
  std::unordered_map<const VarNode*, PrimExpr> var_remap_;
  // Set of vars that have occurred in an AllocateNode, but haven't
  // yet occurred in a BufferLoad/BufferStore.
  std::unordered_set<const VarNode*> buffer_var_defines_;
  // Map from an allocation variable to the buffer(s) that it backs.
  // Used to determine the axis_separators that should be used for
  // flattening the extents of an AllocateNode.
  std::unordered_map<const VarNode*, std::vector<Buffer>> buffer_var_map_;
  // Buffer map
  std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
  // Collects shapes.
  std::vector<std::pair<Var, Array<PrimExpr>>> shape_collector_;
  // Bounds populator; we only need the analyzer that it carries.
  IRVisitorWithAnalyzer* bound_analyzer_;
  // The size of a cache line.
  int cache_line_size_;
  // Whether to mark loads/stores with their bounds.
  bool create_bound_attributes_{false};
};

/*!
 * \brief Simplify assert statements.
 *
 * If an assert statement can be statically verified to be true,
 * remove the assert statement. Otherwise, keep the assert statement
 * unmodified.
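 *
 * For example (a sketch, assuming the bound analyzer can prove n >= 0):
 *   assert(n + 1 > 0) { body }  ==>  body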
 */
class AssertSimplifier : public StmtMutator {
 public:
  static transform::Pass Pass() {
    auto pass_func = [=](PrimFunc func, IRModule m, transform::PassContext ctx) {
      IRVisitorWithAnalyzer bound_analyzer;

      bound_analyzer(func->body);

      auto fptr = func.CopyOnWrite();
      fptr->body = AssertSimplifier(&bound_analyzer)(std::move(fptr->body));
      return func;
    };
    return transform::CreatePrimFuncPass(pass_func, 0, "tir.AssertSimplifier", {});
  }

  explicit AssertSimplifier(IRVisitorWithAnalyzer* bound_analyzer)
      : bound_analyzer_(bound_analyzer) {}

  Stmt VisitStmt_(const AssertStmtNode* op) final {
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<AssertStmtNode>();

    PrimExpr condition = bound_analyzer_->Simplify(op->condition);
    if (is_one(condition)) {
      return op->body;
    }

    return stmt;
  }

 private:
  IRVisitorWithAnalyzer* bound_analyzer_;
};

// The specific tensor data layout is not determined before the
// StorageFlatten pass. We use buffer_bind_scope to specify beforehand
// that we want to bind a subregion of a tensor to a symbolic buffer,
// which gets used in extern calls.
//
// Example:
//
// realize A in range [i*4, extent=10) {
//   bind Ab to A in [i*4+1, extent=4) {
//     call_func(Ab.ptr, Ab.shape[0])
//   }
// }
//
// After StorageFlatten
//
// alloc A[10]
// call(A + 1, 4)
//
// Buffer is a protocol to declare the specific data layout and shape
// we expect. So this function needs to check:
// - If the bind range is within the realize range
// - If we can match the requirements of the buffer
// - Remap variables such as Ab.ptr to the actual value.
//
// Here are a few possible failure cases:
// - Buffer is declared to have a constant shape,
//   but we try to bind it to a different one.
// - Buffer is declared to be compact (no strides),
//   but the bound region is a subregion of
//   a matrix (tensor), which means it requires strides.
//
// We do support a few relaxed cases, such as binding a
// region with shape [1, 1, n, m] to a buffer with shape [n, m].
PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) {
  // Only apply this pass to TIR from TE schedules. Because this is a
  // per-function attribute, we can't just check it once for the
  // entire module and apply the Sequential transform.
  Optional<Bool> from_legacy_te_schedule = func->GetAttr("from_legacy_te_schedule", Bool(false));
  if (from_legacy_te_schedule.value()) {
    auto seq = transform::Sequential(
        {
            BufferShapeLegalize::Pass(),
            BufferStrideLegalize::Pass(),
            ThreadScopePropagate::Pass(),
            BufferBindUnwrapper::Pass(),
            ApplyLayoutTransforms::Pass(),
            StorageFlattener::Pass(cache_line_size, create_bound_attributes),
            AssertSimplifier::Pass(),
        },
        "tir.StorageFlatten_impl");
    GlobalVar dummy_func_name("dummy_func");
    IRModule mod(Map<GlobalVar, BaseFunc>({{dummy_func_name, func}}));
    mod = seq(mod);
    return Downcast<PrimFunc>(mod->Lookup(dummy_func_name));
  } else {
    return func;
  }
}

namespace transform {

TVM_REGISTER_GLOBAL("tir.transform.ApplyLayoutTransforms")
    .set_body_typed(ApplyLayoutTransforms::Pass);

// TODO(tvm-team): consolidate configs to the PassContext
Pass StorageFlatten(int cache_line_size, bool create_bound_attributes) {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    return StorageFlatten(std::move(f), cache_line_size, create_bound_attributes);
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.StorageFlatten", {});
}
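
// As a usage sketch, the registered pass can be applied directly to a
// module, e.g.:
//   IRModule mod = ...;
//   mod = tir::transform::StorageFlatten(64, false)(mod);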

TVM_REGISTER_GLOBAL("tir.transform.StorageFlatten").set_body_typed(StorageFlatten);

}  // namespace transform

}  // namespace tir
}  // namespace tvm