1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file storage_rewrite.cc
22 * \brief Memory access pattern analysis and optimization.
23 * Re-write data access to enable memory sharing when possible.
24 */
25#include <tvm/arith/analyzer.h>
26#include <tvm/ir/type.h>
27#include <tvm/runtime/registry.h>
28#include <tvm/target/target_info.h>
29#include <tvm/tir/analysis.h>
30#include <tvm/tir/builtin.h>
31#include <tvm/tir/expr.h>
32#include <tvm/tir/stmt_functor.h>
33#include <tvm/tir/transform.h>
34
35#include <map>
36#include <unordered_map>
37#include <unordered_set>
38
39#include "../../runtime/thread_storage_scope.h"
40#include "../ir/buffer_common.h"
41#include "ir_utils.h"
42
43namespace tvm {
44namespace tir {
45
46using runtime::StorageRank;
47using runtime::StorageScope;
48
49// Find a linear pattern of storage access
50// Used for liveness analysis.
// Composite scopes (loop/thread_launch/IfThenElse) are represented by two points:
52// before_scope -> scope_body -> after_scope
53//
54// The linear_seq_ stores before_scope and after_scope.
55// The access to the arrays are stored at the after_scope point.
56//
57// Define "scope" as the body of For/thread_launch/IfThenElse
58// This pass tries to detect last point that we need to keep memory
59// alive under the same scope as allocate.
60// The storage need to be kept alive between allocate and last access.
61// The free point is only inserted at the same scope of allocate.
62//
63class LinearAccessPatternFinder final : public StmtExprVisitor {
64 public:
65 /*! \brief record the touch hist of statment. */
66 struct StmtEntry {
67 // The statment
68 const Object* stmt;
69 // The index in the linear_seq_ to point to end of the nested scope.
70 // This is only set to non-zero if stmt is a nested scope.
71 // if offset > 0, means this is the begin, the end entry is current_index + offset
72 // if offset < 0, means this is the end, the begin entry is current_index + offset
73 int64_t scope_pair_offset{0};
74 // The buffer variables this statment touched.
75 std::vector<const VarNode*> touched;
76 };
77 // The scope of each allocation
78 struct AllocEntry {
79 // The physical dimension of the allocation.
80 size_t num_physical_dimensions{0};
81 // scope level
82 size_t level{0};
83 // allocation stmt
84 const AllocateNode* alloc{nullptr};
85 };
86
87 void VisitStmt_(const AllocateNode* op) final {
88 size_t level = scope_.size();
89 const VarNode* buf = op->buffer_var.get();
90
91 AllocEntry entry;
92 entry.alloc = op;
93 entry.level = level;
94 // Since StorageRewrite occurs after StorageFlatten/FlattenBuffer,
95 // all allocations specify the extent of physical dimensions, and
96 // is 1 for flat memory spaces.
97 entry.num_physical_dimensions = op->extents.size();
98 alloc_info_[buf] = entry;
99
100 StmtExprVisitor::VisitStmt_(op);
101 }
102
103 void VisitStmt_(const StoreNode* op) final {
104 LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
105 }
106
107 void VisitStmt_(const BufferStoreNode* op) final {
108 scope_.push_back(StmtEntry());
109 // visit subexpr
110 StmtExprVisitor::VisitStmt_(op);
111 // Add write access.
112 const VarNode* buf = op->buffer->data.get();
113 auto it = alloc_info_.find(buf);
114 if (it != alloc_info_.end() && it->second.alloc) {
115 ICHECK_LT(it->second.level, scope_.size());
116 scope_[it->second.level].touched.push_back(buf);
117
118 ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
119 << "Buffer " << op->buffer->name << " is allocated with "
120 << it->second.num_physical_dimensions
121 << " physical dimensions, but is accessed as having "
122 << op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl;
123 }
124 StmtEntry e = scope_.back();
125 scope_.pop_back();
126 if (e.touched.size() != 0) {
127 e.stmt = op;
128 linear_seq_.push_back(e);
129 }
130 }
131
132 void VisitExpr_(const LoadNode* op) final {
133 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
134 }
135
136 void VisitExpr_(const BufferLoadNode* op) final {
137 // Add write access.
138 StmtExprVisitor::VisitExpr_(op);
139 const VarNode* buf = op->buffer->data.get();
140 auto it = alloc_info_.find(buf);
141 if (it != alloc_info_.end() && it->second.alloc) {
142 ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
143 scope_[it->second.level].touched.push_back(buf);
144
145 ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
146 << "Buffer " << op->buffer->name << " is allocated with "
147 << it->second.num_physical_dimensions
148 << " physical dimensions, but is accessed as having "
149 << op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl;
150 }
151 }
152
153 void VisitStmt_(const EvaluateNode* op) final {
154 scope_.push_back(StmtEntry());
155 // visit subexpr
156 StmtExprVisitor::VisitStmt_(op);
157 StmtEntry e = scope_.back();
158 scope_.pop_back();
159 if (e.touched.size() != 0) {
160 e.stmt = op;
161 linear_seq_.push_back(e);
162 }
163 }
164
165 void VisitExpr_(const CallNode* op) final {
166 if (op->op.same_as(builtin::address_of())) {
167 const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
168 for (const auto& index : load->indices) {
169 this->VisitExpr(index);
170 }
171 } else {
172 StmtExprVisitor::VisitExpr_(op);
173 }
174 }
175
176 void VisitExpr_(const VarNode* buf) final {
177 // Directly reference to the variable count as a read.
178 auto it = alloc_info_.find(buf);
179 if (it != alloc_info_.end() && it->second.alloc) {
180 ICHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint;
181 scope_[it->second.level].touched.push_back(buf);
182 }
183 }
184
185 template <typename T>
186 void VisitNewScope(const T* op) {
187 scope_.push_back(StmtEntry());
188 StmtEntry e;
189 e.stmt = op;
190 int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
191 // before scope.
192 linear_seq_.push_back(e);
193 StmtExprVisitor::VisitStmt_(op);
194 // after scope.
195 e.touched = std::move(scope_.back().touched);
196 scope_.pop_back();
197 int64_t end_index = static_cast<int64_t>(linear_seq_.size());
198 ICHECK_GT(end_index, begin_index);
199 e.scope_pair_offset = begin_index - end_index;
200 linear_seq_.push_back(e);
201 // record the pointer to end index.
202 ICHECK_NE(end_index, 0U);
203 linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
204 }
205
206 void VisitStmt_(const AttrStmtNode* op) final {
207 // Only record the outer most thread extent.
208 if (op->attr_key == attr::thread_extent && !in_thread_env_) {
209 in_thread_env_ = true;
210 VisitNewScope(op);
211 in_thread_env_ = false;
212 } else if (op->attr_key == attr::extern_scope) {
213 VisitNewScope(op);
214 } else if (op->attr_key == attr::virtual_thread) {
215 VisitNewScope(op);
216 } else {
217 StmtExprVisitor::VisitStmt_(op);
218 }
219 }
220
221 void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); }
222
223 void VisitStmt_(const ForNode* op) final { VisitNewScope(op); }
224
225 void VisitStmt_(const WhileNode* op) final { VisitNewScope(op); }
226
227 void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); }
228
229 void VisitStmt_(const LetStmtNode* op) final { VisitNewScope(op); }
230
231 // linearized access sequence.
232 std::vector<StmtEntry> linear_seq_;
233 // The storage scope of each buffer
234 std::unordered_map<const VarNode*, AllocEntry> alloc_info_;
235
236 private:
237 // Whether already in thread env.
238 bool in_thread_env_{false};
239 // The scope stack.
240 std::vector<StmtEntry> scope_;
241};
242
243// Verify if the statement can be run safely via inplace fashion
244//
245// Detect pattern: dst[index] = f(src[index])
246//
247// WARNING: the current detection algorithm cannot handle the case
248// when a location in an array is written multiple times
249//
250// For example, the following program will pass the check,
251// but we cannot make A and B to be the same array.
252//
253// A[0] = B[0] + 1
254// A[0] = B[0] + 1
255//
256// The high level code generator needs to ensure that the generated
257// code only write each location of the target array once.
258//
259// This is the case with IR generated by the current compute schedule.
260// We explicitly return false if we find there is an extern block
261// which can be arbitrary IR.
262//
// Nevertheless, the inplace detector should be used with care in mind.
264// We may also consider introduce a condition checker that checks
265// if every index only visited once for an absolute sufficient condition.
266//
267// The code after inplace transformation is no longer idempotent.
268//
class InplaceOpVerifier : public StmtExprVisitor {
 public:
  // Entry point. Returns true if the statement `stmt` can safely write
  // `dst` in place over `src` (pattern dst[index] = f(src[index])).
  // Only a fixed set of statement kinds is accepted; anything else is
  // conservatively rejected.
  bool Check(const Object* stmt, const VarNode* dst, const VarNode* src) {
    dst_ = dst;
    src_ = src;
    result_ = true;
    if (stmt->IsInstance<AttrStmtNode>()) {
      VisitStmt_(static_cast<const AttrStmtNode*>(stmt));
    } else if (stmt->IsInstance<ForNode>()) {
      VisitStmt_(static_cast<const ForNode*>(stmt));
    } else if (stmt->IsInstance<IfThenElseNode>()) {
      VisitStmt_(static_cast<const IfThenElseNode*>(stmt));
    } else if (stmt->IsInstance<WhileNode>()) {
      VisitStmt_(static_cast<const WhileNode*>(stmt));
    } else if (stmt->IsInstance<StoreNode>()) {
      VisitStmt_(static_cast<const StoreNode*>(stmt));
    } else if (stmt->IsInstance<BufferStoreNode>()) {
      VisitStmt_(static_cast<const BufferStoreNode*>(stmt));
    } else {
      return false;
    }
    return result_;
  }

  using StmtExprVisitor::VisitStmt_;

  // Short-circuit traversal once the check has already failed.
  void VisitStmt(const Stmt& n) final {
    if (!result_) return;
    StmtExprVisitor::VisitStmt(n);
  }
  void VisitExpr_(const PrimExpr& n) final {
    if (!result_) return;
    StmtExprVisitor::VisitExpr(n);
  }

  void VisitExpr_(const VarNode* op) final {
    // assume all opaque access is unsafe
    if (op == dst_ || op == src_) {
      result_ = false;
      return;
    }
  }

  void VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  void VisitStmt_(const BufferStoreNode* op) final {
    // Indices count as indirect memory access: a load of src_/dst_ inside
    // an index would make the inplace rewrite unsafe, so visit them with
    // mem_nest_ raised.
    ++mem_nest_;
    for (const auto& index : op->indices) {
      this->VisitExpr(index);
    }
    --mem_nest_;
    if (op->buffer->data.get() == dst_) {
      // Remember the store so loads of src_ in the value can verify that
      // they use exactly the same indices (see BufferLoad visitor).
      store_ = op;
      this->VisitExpr(op->value);
      store_ = nullptr;
    } else {
      this->VisitExpr(op->value);
    }
  }

  void VisitStmt_(const AttrStmtNode* op) final {
    // always reject extern code
    if (op->attr_key == attr::extern_scope || op->attr_key == attr::volatile_scope) {
      result_ = false;
      return;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    const VarNode* buf = op->buffer->data.get();
    // cannot read from dst_ (no reduction)
    if (buf == dst_) {
      result_ = false;
      return;
    }
    // do not allow indirect memory load
    if (mem_nest_ != 0) {
      result_ = false;
      return;
    }
    if (src_ == buf) {
      // A load from src_ is only safe when it mirrors the enclosing store
      // to dst_: same dtype and element-wise identical indices.
      if (store_ == nullptr || store_->value.dtype() != op->dtype) {
        result_ = false;
        return;
      }
      ICHECK_EQ(store_->indices.size(), op->indices.size())
          << "Store/Load occur to the same buffer " << buf->name_hint
          << " with differing number of indices";
      for (size_t i = 0; i < store_->indices.size(); i++) {
        if (!tir::ExprDeepEqual()(store_->indices[i], op->indices[i])) {
          result_ = false;
          return;
        }
      }
    }
    // Sub-expressions of the indices are nested memory accesses.
    ++mem_nest_;
    StmtExprVisitor::VisitExpr_(op);
    --mem_nest_;
  }

 private:
  // result of the check
  bool result_{true};
  // destination memory
  const VarNode* dst_;
  // source variable
  const VarNode* src_;
  // counter of load,
  // it is not safe to inplace when there is nested load like A[B[i]]
  int mem_nest_{0};
  // The current store to be inspected
  const BufferStoreNode* store_{nullptr};
};
389
390/* \brief Rewrite and merge memory allocation.
391 *
392 * Using LinearAccessPatternFinder, determines which buffers could share an
393 * allocation. This includes both sequential usage of the same buffer and
394 * merging small allocations at the same scope into a single larger allocation.
395 * The merging of small allocations requires the codegen to cast the resulting
396 * value from the storage type to the output type after access.
397 */
398class StoragePlanRewriter : public StmtExprMutator {
399 public:
400 using StmtEntry = LinearAccessPatternFinder::StmtEntry;
401 using AllocEntry = LinearAccessPatternFinder::AllocEntry;
402
403 Stmt Rewrite(Stmt stmt, bool detect_inplace) {
404 detect_inplace_ = detect_inplace;
405 // plan the rewrite
406 LinearAccessPatternFinder finder;
407 finder(stmt);
408 this->LivenessAnalysis(finder.linear_seq_);
409 this->PlanMemory(finder.linear_seq_, finder.alloc_info_);
410 this->PrepareNewAlloc();
411 // start rewrite
412 stmt = operator()(std::move(stmt));
413 if (attach_map_.count(nullptr)) {
414 return MakeAttach(attach_map_.at(nullptr), stmt);
415 }
416 return stmt;
417 }
418
  // StoreNode is deprecated in favor of BufferStoreNode; reaching this
  // visitor means the input IR is in an unsupported legacy form.
  Stmt VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }
422
  // LoadNode is deprecated in favor of BufferLoadNode; reaching this
  // visitor means the input IR is in an unsupported legacy form.
  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }
426
427 template <typename Node>
428 Node VisitBufferAccess(Node node) {
429 auto it = alloc_map_.find(node->buffer->data.get());
430 if (it != alloc_map_.end()) {
431 Buffer buf = RemapBuffer(node->buffer, it->second->alloc_var);
432
433 Array<PrimExpr> indices = node->indices;
434 indices.Set(indices.size() - 1,
435 RemapIndex(node->buffer->dtype, indices[indices.size() - 1], it->second));
436
437 auto writer = node.CopyOnWrite();
438 writer->buffer = buf;
439 writer->indices = indices;
440 }
441 return node;
442 }
443
444 Buffer RemapBuffer(Buffer buf, Var new_backing_array) {
445 auto key = buf.get();
446 auto it = buffer_remap_.find(key);
447 if (it != buffer_remap_.end()) {
448 ICHECK_EQ(it->second->data.get(), new_backing_array.get())
449 << "Cannot remap buffer " << buf->name << " to use backing array "
450 << new_backing_array->name_hint << ", previously used backing array "
451 << it->second->data->name_hint;
452 return it->second;
453 }
454
455 Buffer remapped = Buffer(new_backing_array, buf->dtype, buf->shape, buf->strides,
456 buf->elem_offset, new_backing_array->name_hint, buf->data_alignment,
457 buf->offset_factor, buf->buffer_type, buf->axis_separators, buf->span);
458 buffer_remap_[key] = remapped;
459 return remapped;
460 }
461
462 Stmt VisitStmt_(const BufferStoreNode* op) final {
463 auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
464 return VisitBufferAccess(std::move(node));
465 }
466
467 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
468 auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
469 return VisitBufferAccess(std::move(node));
470 }
471
472 PrimExpr VisitExpr_(const VarNode* op) final {
473 auto it = alloc_map_.find(op);
474 if (it != alloc_map_.end()) {
475 if (it->second->bits_offset != 0) {
476 LOG(WARNING) << "Use a merged buffer variable address, could cause error";
477 }
478 return it->second->alloc_var;
479 } else {
480 return GetRef<PrimExpr>(op);
481 }
482 }
483 PrimExpr VisitExpr_(const CallNode* op) final {
484 if (op->op.same_as(builtin::tvm_access_ptr())) {
485 ICHECK_EQ(op->args.size(), 5U);
486 DataType dtype = op->args[0].dtype();
487 const VarNode* buffer = op->args[1].as<VarNode>();
488 auto it = alloc_map_.find(buffer);
489 if (it == alloc_map_.end()) {
490 return StmtExprMutator::VisitExpr_(op);
491 }
492 const StorageEntry* se = it->second;
493 PrimExpr offset = this->VisitExpr(op->args[2]);
494 PrimExpr extent = this->VisitExpr(op->args[3]);
495 uint64_t elem_bits = dtype.bits() * dtype.lanes();
496 ICHECK_EQ(se->bits_offset % elem_bits, 0U);
497 if (se->bits_offset != 0) {
498 offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset;
499 }
500 return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset, extent, op->args[4]});
501 } else {
502 return StmtExprMutator::VisitExpr_(op);
503 }
504 }
505
  // Handle attribute scopes:
  // - thread_extent / virtual_thread / pragma: re-emit planned allocations
  //   attached to this scope around the mutated body.
  // - volatile_scope: retarget the attribute to the remapped variable.
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread ||
        attr::IsPragmaKey(op->attr_key)) {
      // remake all the allocation at the attach scope.
      if (attach_map_.count(op)) {
        auto& svec = attach_map_[op];
        Stmt stmt = StmtExprMutator::VisitStmt_(op);
        op = stmt.as<AttrStmtNode>();
        return AttrStmt(op->node, op->attr_key, op->value, MakeAttach(svec, op->body));
      } else {
        return StmtExprMutator::VisitStmt_(op);
      }
    } else if (op->attr_key == attr::volatile_scope) {
      Stmt stmt = StmtExprMutator::VisitStmt_(op);
      op = stmt.as<AttrStmtNode>();
      auto it = alloc_map_.find(op->node.as<VarNode>());
      // If the annotated variable was not re-planned, keep the stmt as-is.
      if (it == alloc_map_.end()) return stmt;
      return AttrStmt(it->second->alloc_var, op->attr_key, op->value, op->body);
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }
528
529 Stmt VisitStmt_(const ForNode* op) final {
530 ICHECK(op->kind != ForKind::kVectorized) << "VectorizeLoop before LiftStorageAlloc";
531 // remake all the allocation at the attach scope.
532 if (attach_map_.count(op)) {
533 auto& svec = attach_map_[op];
534 Stmt stmt = StmtExprMutator::VisitStmt_(op);
535 op = stmt.as<ForNode>();
536 return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, op->body),
537 op->thread_binding, op->annotations);
538 } else {
539 return StmtExprMutator::VisitStmt_(op);
540 }
541 }
542
  // Drop the original allocation; the (possibly merged) replacement is
  // re-inserted at its attach scope via MakeAttach.
  Stmt VisitStmt_(const AllocateNode* op) final { return this->VisitStmt(op->body); }
544
545 private:
  // Planned storage for one or more original allocations that share memory.
  struct StorageEntry {
    // The scope that this alloc attaches after
    // For shared/local memory it is beginning of the thread extent.
    // for global memory it is nullptr, means beginning of everything.
    const Object* attach_scope_{nullptr};
    // The constant size of the buffer in bits, only used if it is constant
    uint64_t const_nbits{0};
    // The storage scope.
    StorageScope scope;
    // The physical dimensionality of the allocations. Since
    // StorageRewrite is applied after StorageFlatten/FlattenBuffer,
    // this is size of `AllocateNode::extents`. If moved
    // to before StorageFlatten/FlattenBuffer, this should instead be
    // the size of `BufferNode::shape`.
    size_t ndim;
    // Allocs that shares this entry.
    std::vector<const AllocateNode*> allocs;
    // The children of this entry, not including itself.
    std::vector<StorageEntry*> merged_children;
    // The replacement allocation, if any.
    Stmt new_alloc;
    // The var expr of new allocation.
    Var alloc_var;
    // The allocation element type.
    DataType elem_type;
    // This is non-zero if this allocate is folded into another one
    // the address(in bits) becomes alloc_var + bits_offset;
    // can be effectively converted to the element type.
    // We need to convert bit_offset to offset of specific element type later.
    //
    // We use bits(instead of bytes) to support non-conventional indexing in hardware.
    // When we are merging buffer together, the bits_offset are set to be aligned
    // to certain value given by the max_simd_bits property of the special memory.
    //
    // This allows effective sharing among different types as long as their alignment
    // requirement fits into the max_simd_bits.
    uint64_t bits_offset{0};
  };
582
583 // Checks whether the storage_scope is especially tagged for a specific memory.
584 // Special memory is all combined into a single allocation.
585 bool IsSpecialTaggedMemory(const StorageScope& scope) {
586 return scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".workspace" &&
587 scope.tag != ".vtcm";
588 }
589
  // Allocate entry of node.
  // Event entry in liveness analysis: buffers whose lifetime starts (gen)
  // or ends (kill) at a given statement.
  struct EventEntry {
    // variables we generate
    std::vector<const VarNode*> gen;
    // variables we kill
    std::vector<const VarNode*> kill;
  };
598
599 Stmt MakeAttach(const std::vector<StorageEntry*>& svec, Stmt body) {
600 std::vector<Stmt> nest;
601 for (StorageEntry* e : svec) {
602 if (e->new_alloc.defined()) {
603 nest.push_back(e->new_alloc);
604 }
605 }
606 return MergeNest(nest, body);
607 }
608 // Remap the index
609 PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry* e) {
610 if (e->bits_offset == 0) return index;
611 uint64_t elem_bits = dtype.bits();
612 ICHECK_EQ(e->bits_offset % elem_bits, 0U);
613 return make_const(index.dtype(), e->bits_offset / elem_bits) + index;
614 }
615 // Prepare the new allocations
616 void PrepareNewAlloc() {
617 for (size_t i = 0; i < alloc_vec_.size(); ++i) {
618 StorageEntry* e = alloc_vec_[i].get();
619 attach_map_[e->attach_scope_].push_back(e);
620 }
621 // find allocation via attach map.
622 for (auto& kv : attach_map_) {
623 // find the element with the most amount of bytes.
624 std::vector<StorageEntry*>& vec = kv.second;
625 // try to find merge, for tagged memory
626 for (size_t i = 0; i < vec.size(); ++i) {
627 StorageEntry* e = vec[i];
628 if (IsSpecialTaggedMemory(e->scope)) {
629 ICHECK_NE(e->const_nbits, 0U) << "Special tagged memory must be const size";
630 for (size_t j = 0; j < i; ++j) {
631 if (e->scope == vec[j]->scope) {
632 vec[j]->merged_children.push_back(e);
633 break;
634 }
635 }
636 }
637 }
638 // Start allocation
639 for (size_t i = 0; i < vec.size(); ++i) {
640 StorageEntry* e = vec[i];
641 // already merged
642 if (e->bits_offset != 0) continue;
643 if (e->merged_children.size() != 0) {
644 NewAllocTagMerged(e);
645 continue;
646 }
647 // Get the allocation size;
648 e->alloc_var = e->allocs[0]->buffer_var;
649 DataType alloc_type = e->allocs[0]->dtype;
650 for (const AllocateNode* op : e->allocs) {
651 if (op->dtype.lanes() > alloc_type.lanes()) {
652 alloc_type = op->dtype;
653 }
654 }
655
656 bool all_allocs_identical = std::all_of(
657 e->allocs.begin() + 1, e->allocs.end(), [&](const AllocateNode* op) -> bool {
658 const AllocateNode* first = *e->allocs.begin();
659 if (op->dtype != first->dtype) {
660 return false;
661 }
662 if (op->extents.size() != first->extents.size()) {
663 return false;
664 }
665 ExprDeepEqual expr_equal;
666 for (size_t i = 0; i < op->extents.size(); i++) {
667 if (!expr_equal(op->extents[i], first->extents[i])) {
668 return false;
669 }
670 }
671 return true;
672 });
673
674 if (all_allocs_identical) {
675 // simply use the original allocation.
676 e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
677 e->allocs[0]->condition, Evaluate(0));
678 if (IsSpecialTaggedMemory(e->scope)) {
679 MemoryInfo info = GetMemoryInfo(e->scope.to_string());
680 if (info.defined()) {
681 uint64_t total_elem = e->const_nbits / e->elem_type.bits();
682 ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
683 << "Allocation exceed bound of memory tag " << e->scope.to_string();
684 }
685 }
686 } else {
687 // Build a merged allocation
688 PrimExpr combo_size;
689 for (const AllocateNode* op : e->allocs) {
690 ICHECK_EQ(op->extents.size(), 1)
691 << "Buffer var " << op->buffer_var->name_hint
692 << " was identified as a re-usable allocation, but has " << op->extents.size()
693 << " physical dimensions. "
694 << "Currently, only flat 1-d memory spaces should be identified as re-usable "
695 "allocations.";
696 PrimExpr sz = op->extents[0];
697 auto nbits = op->dtype.bits() * op->dtype.lanes();
698 if (const auto* imm = sz.as<IntImmNode>()) {
699 if (imm->value > std::numeric_limits<int>::max() / nbits) {
700 LOG(WARNING) << "The allocation requires : " << imm->value << " * " << nbits
701 << " bits, which is greater than the maximum of"
702 " int32. The size is cast to int64."
703 << "\n";
704 sz = make_const(DataType::Int(64), imm->value);
705 }
706 }
707 // transform to bits
708 auto sz_nbits = sz * nbits;
709 if (combo_size.defined()) {
710 combo_size = max(combo_size, sz_nbits);
711 } else {
712 combo_size = sz_nbits;
713 }
714 }
715 // transform to alloc bytes
716 auto type_bits = alloc_type.bits() * alloc_type.lanes();
717 bool divided = analyzer_.CanProve(indexmod(combo_size, type_bits) == 0);
718 combo_size = indexdiv(combo_size, type_bits);
719 // round up for can not divided
720 if (!divided) {
721 combo_size = combo_size + make_const(DataType::Int(32), 1);
722 }
723 combo_size = analyzer_.Simplify(combo_size);
724 e->new_alloc =
725 Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0));
726 if (IsSpecialTaggedMemory(e->scope)) {
727 MemoryInfo info = GetMemoryInfo(e->scope.to_string());
728 if (info.defined()) {
729 uint64_t total_elem = e->const_nbits / e->elem_type.bits();
730 ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
731 << "Allocation exceed bound of memory tag " << e->scope.to_string();
732 }
733 }
734 }
735 }
736 }
737 }
738 // New allocation for merged data
739 void NewAllocTagMerged(StorageEntry* e) {
740 ICHECK_NE(e->scope.tag.length(), 0U);
741 // allocate with element type.
742 ICHECK_NE(e->const_nbits, 0U);
743 MemoryInfo info = GetMemoryInfo(e->scope.to_string());
744 uint64_t total_bits = e->const_nbits;
745 // By default, align to 32 bits.
746 size_t align = 32;
747 if (info.defined()) {
748 align = info->max_simd_bits;
749 }
750 // Always align to max_simd_bits
751 // so we can remap types by keeping this property
752 if (total_bits % align != 0) {
753 total_bits += align - (total_bits % align);
754 }
755 e->alloc_var = e->allocs[0]->buffer_var;
756 for (StorageEntry* child : e->merged_children) {
757 ICHECK_NE(child->const_nbits, 0U);
758 ICHECK_NE(total_bits, 0U);
759 child->bits_offset = total_bits;
760 child->alloc_var = e->alloc_var;
761 total_bits += child->const_nbits;
762 if (total_bits % align != 0) {
763 total_bits += align - (total_bits % align);
764 }
765 }
766 uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
767 PrimExpr alloc_size =
768 make_const(e->allocs[0]->extents[0].dtype(), (total_bits + type_bits - 1) / type_bits);
769 e->new_alloc = Allocate(e->alloc_var, e->elem_type, {alloc_size}, const_true(), Evaluate(0));
770 if (info.defined()) {
771 ICHECK_LE(total_bits, info->max_num_bits)
772 << "Allocation exceed bound of memory tag " << e->scope.to_string();
773 }
774 }
775 // Liveness analysis to find gen and kill point of each variable.
776 void LivenessAnalysis(const std::vector<StmtEntry>& seq) {
777 // find kill point, do a reverse linear scan.
778 std::unordered_set<const VarNode*> touched;
779 for (size_t i = seq.size(); i != 0; --i) {
780 const StmtEntry& s = seq[i - 1];
781 for (const VarNode* buffer : s.touched) {
782 if (!touched.count(buffer)) {
783 touched.insert(buffer);
784 event_map_[s.stmt].kill.push_back(buffer);
785 }
786 }
787 }
788 // find gen point, do forward scan
789 touched.clear();
790 for (size_t i = 0; i < seq.size(); ++i) {
791 int64_t offset = seq[i].scope_pair_offset;
792 if (offset < 0) continue;
793 const StmtEntry& s = seq[i + offset];
794 for (const VarNode* buffer : s.touched) {
795 if (!touched.count(buffer)) {
796 touched.insert(buffer);
797 event_map_[s.stmt].gen.push_back(buffer);
798 }
799 }
800 }
801 }
  // Toggle the current thread scope. On scope entry `op` is remembered; on
  // the matching exit, free-lists entries attached to that scope are purged
  // so memory is never reused across a thread-scope boundary.
  void PlanNewScope(const Object* op) {
    if (thread_scope_ != nullptr) {
      ICHECK(thread_scope_ == op);
      // erase all memory attached to this scope.
      for (auto it = const_free_map_.begin(); it != const_free_map_.end();) {
        if (it->second->attach_scope_ == op) {
          it = const_free_map_.erase(it);
        } else {
          ++it;
        }
      }
      for (auto it = sym_free_list_.begin(); it != sym_free_list_.end();) {
        if ((*it)->attach_scope_ == op) {
          it = sym_free_list_.erase(it);
        } else {
          ++it;
        }
      }
      thread_scope_ = nullptr;
    } else {
      thread_scope_ = op;
    }
  }
825
  // Memory plan algorithm
  // Walks the linearized sequence once, assigning each generated buffer a
  // StorageEntry: first by inplace sharing with a dying buffer when the
  // InplaceOpVerifier approves, otherwise via FindAlloc (free-list reuse or
  // a fresh entry). Killed buffers are returned to the free lists.
  void PlanMemory(const std::vector<StmtEntry>& seq,
                  const std::unordered_map<const VarNode*, AllocEntry>& alloc_info) {
    // Buffers already consumed as inplace sources; they must not be freed.
    std::unordered_set<const VarNode*> inplace_flag;

    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry& s = seq[i];
      auto it = event_map_.find(seq[i].stmt);

      // scope_pair_offset >= 0 means it is either
      // - leaf stmt(offset = 0)
      // - beginning of scope(offset < 0)
      // In both cases, we need to handle the gen event correctly
      if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
        // Inplace operation detection
        // specially handle this
        bool detect_inplace = detect_inplace_ && (it->second.gen.size() <= 2);

        for (const VarNode* var : it->second.gen) {
          ICHECK(alloc_info.count(var));
          const AllocEntry& entry = alloc_info.at(var);
          const AllocateNode* alloc = entry.alloc;
          auto storage_scope = StorageScope::Create(GetPtrStorageScope(GetRef<Var>(var)));
          StorageEntry* dst_entry = nullptr;
          // inplace detection
          if (detect_inplace) {
            // only one inplace var for s.stmt
            bool inplace_found = false;
            for (const VarNode* src : it->second.kill) {
              if (!inplace_flag.count(src) && alloc_map_.count(src)) {
                InplaceOpVerifier visitor;
                StorageEntry* src_entry = alloc_map_.at(src);
                // Inplace sharing requires matching scope, attach point,
                // element type and exact constant size.
                if (src_entry->scope == storage_scope &&
                    src_entry->attach_scope_ == thread_scope_ &&
                    src_entry->elem_type == alloc->dtype.element_of() &&
                    visitor.Check(s.stmt, var, src)) {
                  uint64_t const_nbits = static_cast<uint64_t>(alloc->ConstantAllocationSize()) *
                                         alloc->dtype.bits() * alloc->dtype.lanes();
                  if (src_entry->const_nbits == const_nbits && !inplace_found) {
                    // successfully inplace
                    dst_entry = src_entry;
                    inplace_flag.insert(src);
                    inplace_found = true;
                  }
                }
              }
            }
          }
          if (dst_entry == nullptr) {
            dst_entry =
                FindAlloc(alloc, thread_scope_, storage_scope, entry.num_physical_dimensions);
          }
          dst_entry->allocs.emplace_back(alloc);
          alloc_map_[var] = dst_entry;
        }
      }
      // enter/exit new scope
      if (s.stmt->IsInstance<AttrStmtNode>()) {
        const auto* op = static_cast<const AttrStmtNode*>(s.stmt);
        if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread ||
            attr::IsPragmaKey(op->attr_key)) {
          PlanNewScope(op);
        } else {
          ICHECK(op->attr_key == attr::extern_scope);
        }
      } else if (s.stmt->IsInstance<ForNode>()) {
        const auto* op = static_cast<const ForNode*>(s.stmt);
        if (op->kind == ForKind::kParallel) {
          if (thread_scope_ == nullptr || thread_scope_ == op) {
            PlanNewScope(op);
          }
        }
      }
      // scope_pair_offset <= 0 means it is either
      // - leaf stmt(offset = 0)
      // - end of scope(offset < 0)
      // In both cases, we need to handle the kill event correctly
      if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
        for (const VarNode* var : it->second.kill) {
          // skip space which are already replaced by inplace
          if (!inplace_flag.count(var)) {
            this->Free(var);
          }
        }
      }
    }
  }
913 // Allocate new storage entry.
914 StorageEntry* NewAlloc(const AllocateNode* op, const Object* attach_scope,
915 const StorageScope& scope, size_t const_nbits) {
916 ICHECK(op != nullptr);
917 // Re-use not successful, allocate a new buffer.
918 auto entry = std::make_unique<StorageEntry>();
919 entry->attach_scope_ = attach_scope;
920 entry->scope = scope;
921 entry->elem_type = op->dtype.element_of();
922 entry->const_nbits = const_nbits;
923 StorageEntry* e = entry.get();
924 alloc_vec_.emplace_back(std::move(entry));
925 return e;
926 }
927
  // Search the free pools for a storage entry that can be reused for the
  // allocation `op`; fall back to NewAlloc when no compatible entry exists.
  //
  // @param op The allocation being planned.
  // @param attach_scope The composite scope (thread launch / parallel loop)
  //        the allocation attaches to; reuse only happens within one scope.
  // @param scope The storage scope of the allocation.
  // @param num_physical_dimensions Number of physical dimensions of the
  //        backing memory; only flat (1-d) memory spaces are pooled.
  StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope,
                          const StorageScope& scope, size_t num_physical_dimensions) {
    ICHECK(op != nullptr);
    // Only consider candidate buffers whose size is within a factor of
    // `match_range` of the requested size, to limit wasted memory.
    const uint64_t match_range = 16;
    uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
    // Total requested size in bits; zero when the extent is symbolic.
    uint64_t const_nbits = static_cast<uint64_t>(op->ConstantAllocationSize() * op_elem_bits);

    // If the size of the array isn't known at compile-time, it must
    // have its own allocation with size determined at runtime.
    bool is_known_size = (const_nbits != 0);

    // Currently, only flat memory spaces can be re-used. Packing
    // into N-d space (e.g. 2-d texture memory on GPUs) will require
    // more in-depth algorithms.
    bool is_flat_memory_space = (num_physical_dimensions == 1);

    // Disable reuse of small arrays: they will be lowered to registers in
    // LLVM. This rule only applies to plain (untagged) memory scopes.
    bool is_small_array =
        (scope.tag.length() == 0) && (scope.rank >= StorageRank::kWarp || op->dtype.is_handle() ||
                                      (is_known_size && const_nbits <= 32));

    if (is_small_array || !is_flat_memory_space) {
      return NewAlloc(op, attach_scope, scope, const_nbits);
    }

    if (is_known_size) {
      // Constant-size allocation: best-fit search in const_free_map_,
      // restricted to [const_nbits / match_range, const_nbits * match_range].
      auto begin = const_free_map_.lower_bound(const_nbits / match_range);
      auto mid = const_free_map_.lower_bound(const_nbits);
      auto end = const_free_map_.upper_bound(const_nbits * match_range);
      // start looking at the buffer that is bigger than the required size first
      for (auto it = mid; it != end; ++it) {
        StorageEntry* e = it->second;
        if (e->attach_scope_ != attach_scope) continue;
        if (e->scope != scope) continue;
        // Candidate's packed bit offset must be element-aligned for the new
        // dtype, otherwise no reuse (e.g. float4 vs float3 layouts).
        if (e->bits_offset % op_elem_bits != 0) continue;
        e->const_nbits = std::max(const_nbits, e->const_nbits);
        const_free_map_.erase(it);
        return e;
      }
      // Then look at smaller buffers, largest first; the reused entry grows
      // to the requested size via the std::max below.
      // NOTE(review): this branch checks elem_type but not bits_offset
      // alignment, while the branch above checks bits_offset but not
      // elem_type — confirm the asymmetry is intentional.
      for (auto it = mid; it != begin;) {
        --it;
        StorageEntry* e = it->second;
        if (e->attach_scope_ != attach_scope) continue;
        if (e->scope != scope) continue;
        if (e->elem_type != op->dtype.element_of()) continue;
        e->const_nbits = std::max(const_nbits, e->const_nbits);
        const_free_map_.erase(it);
        return e;
      }
    } else {
      // Symbolic size: first-fit scan over the symbolic free list.
      // (Original comment called this "round robin".)
      for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) {
        StorageEntry* e = *it;
        if (e->attach_scope_ != attach_scope) continue;
        if (e->scope != scope) continue;
        if (e->elem_type != op->dtype.element_of()) continue;
        sym_free_list_.erase(it);
        return e;
      }
    }
    // Nothing reusable was found.
    return NewAlloc(op, attach_scope, scope, const_nbits);
  }
996 // simulated free.
997 void Free(const VarNode* var) {
998 auto it = alloc_map_.find(var);
999 ICHECK(it != alloc_map_.end());
1000 StorageEntry* e = it->second;
1001 ICHECK_NE(e->allocs.size(), 0U);
1002
1003 // disable reuse of small arrays, they will be lowered to registers in LLVM
1004 // This rules only apply if we are using non special memory
1005 if (e->scope.tag.length() == 0) {
1006 // Disable sharing of local memory.
1007 if (e->scope.rank >= StorageRank::kWarp || e->allocs[0]->dtype.is_handle()) return;
1008 // disable reuse of small arrays
1009 if (e->const_nbits > 0 && e->const_nbits <= 32) return;
1010 }
1011 // normal free.
1012 if (e->const_nbits != 0) {
1013 const_free_map_.insert({e->const_nbits, e});
1014 } else {
1015 sym_free_list_.push_back(e);
1016 }
1017 }
  // The current composite scope (thread launch / parallel loop) that new
  // allocations attach to; nullptr outside any such scope.
  const Object* thread_scope_{nullptr};
  // Whether inplace detection is enabled.
  bool detect_inplace_{false};
  // Per-statement access/kill events gathered by the liveness analysis.
  std::unordered_map<const Object*, EventEntry> event_map_;
  // Free pool of constant-size entries, keyed by size in bits for
  // best-fit lookup in FindAlloc.
  std::multimap<uint64_t, StorageEntry*> const_free_map_;
  // Free pool for entries whose size is not known at compile time.
  std::list<StorageEntry*> sym_free_list_;
  // Storage entries grouped by the scope object they are attached to.
  std::unordered_map<const Object*, std::vector<StorageEntry*>> attach_map_;
  // Maps each allocated buffer var to its assigned storage entry.
  std::unordered_map<const VarNode*, StorageEntry*> alloc_map_;
  // Owns every StorageEntry created by NewAlloc.
  std::vector<std::unique_ptr<StorageEntry>> alloc_vec_;
  // The buffer objects being remapped
  std::unordered_map<const BufferNode*, Buffer> buffer_remap_;
  // Analyzer used to simplify expressions during rewriting.
  arith::Analyzer analyzer_;
1038};
1039
1040/* Helper struct containing information on how a buffer is declared and used
1041 *
1042 */
1043struct BufferVarInfo {
1044 enum DeclarationLocation {
1045 kPrimFuncParam = (1 << 0),
1046 kPrimFuncBufferMap = (1 << 1),
1047 kAllocateNode = (1 << 2),
1048 kAllocateConstNode = (1 << 3),
1049 kLetNode = (1 << 4),
1050 };
1051
1052 // The tir::Var that represents this buffer.
1053 Var var;
1054
1055 // The data type of an element of the buffer.
1056 DataType element_dtype;
1057
1058 /* The extent of the buffer.
1059 *
1060 * If multidimensional, the extent of the last dimension of the buffer. If the
1061 * size is unknown (e.g. pointer arguments to PrimFunc with no corresponding
1062 * entry in buffer_map), then extent is zero.
1063 */
1064 PrimExpr extent;
1065
1066 // Where the buffer was declared
1067 DeclarationLocation declaration_location;
1068
1069 // When accessed, which element type is it accessed as. This may
1070 // differ both in base type (e.g. int32* cast to float32* after
1071 // packing in StorageRewrite) or in number of lanes (e.g. float16*
1072 // cast to float16x4*).
1073 std::unordered_set<DataType> access_dtype;
1074
1075 DataType get_preferred_dtype() const {
1076 std::unordered_set<DataType> base_access_dtype;
1077 for (auto dtype : access_dtype) {
1078 base_access_dtype.insert(dtype.element_of());
1079 }
1080 // If the array is accessed as multiple base types within a
1081 // function, no point in changing the declared type. CodeGenC can
1082 // handle this with a type-cast prior to indexing. Vulkan will
1083 // raise an error at code-gen time, if a later pass doesn't split
1084 // it out.
1085 if (base_access_dtype.size() != 1) {
1086 return element_dtype;
1087 }
1088
1089 DataType preferred_base_type = *base_access_dtype.begin();
1090
1091 // If there is only one vectorizable size used to access the
1092 // buffer, and if that access size is compatible with the array
1093 // size, then the buffer is vectorizable. In the future, this
1094 // could be improved to allow vectorized buffer access of size
1095 // GCD(*lanes_used), if necessary.
1096 int preferred_lanes = element_dtype.lanes();
1097 if ((element_dtype.lanes() == 1) && (access_dtype.size() == 1)) {
1098 arith::Analyzer analyzer_;
1099 arith::ModularSet me = analyzer_.modular_set(extent);
1100
1101 int lanes = access_dtype.begin()->lanes();
1102 if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) {
1103 preferred_lanes = lanes;
1104 }
1105 }
1106
1107 return preferred_base_type.with_lanes(preferred_lanes);
1108 }
1109};
1110
/* Checks whether buffers are accessed as scalar or vector parameters in a
 * function.
 *
 * For every buffer variable it records the declared element type and the set
 * of dtypes used to access it; the resulting info_map_ is consumed by
 * VectorTypeRewriter.
 */
class VectorTypeAccessChecker : public StmtExprVisitor {
 public:
  /* Constructor
   *
   * @param params The parameters passed to a PrimFunc
   *
   * @param buffer_map The buffer_map associated with a PrimFunc
   *
   * @param allow_untyped_pointers If a buffer or pointer variable is
   * missing a type annotation, assume that it has the same underlying
   * type as it is later accessed, with scalar element types.
   */
  VectorTypeAccessChecker(const Array<tir::Var>& params, const Map<Var, Buffer>& buffer_map,
                          bool allow_untyped_pointers = false)
      : allow_untyped_pointers_(allow_untyped_pointers) {
    // If a parameter is in the buffer map, we want to track the
    // version in the map.
    for (auto it : buffer_map) {
      Buffer& buffer = it.second;
      Var buffer_var = buffer->data;
      DataType dtype = buffer->dtype;
      // Extent of the last dimension; 0 when the shape is empty/unknown.
      PrimExpr extent = buffer->shape.size() ? buffer->shape[buffer->shape.size() - 1] : 0;
      // NOTE(review): buffer_map entries are tagged kPrimFuncParam here while
      // the loop below tags non-buffer pointer params as kPrimFuncBufferMap.
      // The tags look swapped relative to their names; in-tree callers pass
      // the same rewrite flag for both so behavior is unaffected — confirm
      // before relying on the distinction.
      OnArrayDeclaration(buffer_var, dtype, extent, BufferVarInfo::kPrimFuncParam);
    }

    // If a pointer parameter isn't in the buffer map, then we want to
    // track the parameter itself.
    for (Var buffer_var : params) {
      auto pointer_type = GetPointerType(buffer_var->type_annotation);
      if (pointer_type.has_value() && (buffer_map.count(buffer_var) == 0)) {
        DataType dtype = pointer_type.value();
        PrimExpr extent = 0;
        OnArrayDeclaration(buffer_var, dtype, extent, BufferVarInfo::kPrimFuncBufferMap);
      }
    }
  }

  void VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  void VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  // Record the dtype used by each buffer load.
  void VisitExpr_(const BufferLoadNode* op) final {
    OnArrayAccess(op->dtype, op->buffer->data.get(), op->indices);
    StmtExprVisitor::VisitExpr_(op);
  }

  // Record the dtype used by each buffer store.
  void VisitStmt_(const BufferStoreNode* op) final {
    OnArrayAccess(op->value.dtype(), op->buffer->data.get(), op->indices);
    StmtExprVisitor::VisitStmt_(op);
  }

  // tvm_access_ptr(type_annotation, buffer, index, ...) also counts as an
  // access to `buffer`.
  void VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::tvm_access_ptr())) {
      DataType dtype = op->args[0].dtype();
      const VarNode* buffer = op->args[1].as<VarNode>();
      PrimExpr index = op->args[2];
      OnArrayAccess(dtype, buffer, {index});
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const AllocateNode* op) final {
    const Array<PrimExpr>& extents = op->extents;
    PrimExpr extent = extents[extents.size() - 1];
    OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateNode);

    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const AllocateConstNode* op) final {
    const Array<PrimExpr>& extents = op->extents;
    PrimExpr extent = extents.size() ? extents[extents.size() - 1] : NullValue<PrimExpr>();
    OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateConstNode);

    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitExpr_(const LetNode* op) final {
    HandleLetNode(op->var);
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const LetStmtNode* op) final {
    HandleLetNode(op->var);
    StmtExprVisitor::VisitStmt_(op);
  }

  // Register a let-bound handle variable as an array declaration.  Fails
  // when the variable has no pointer type annotation and untyped pointers
  // are not allowed.
  void HandleLetNode(Var let_var) {
    if (let_var->dtype.is_handle()) {
      auto pointer_type = GetPointerType(let_var->type_annotation);
      if (pointer_type.has_value()) {
        OnArrayDeclaration(let_var, pointer_type.value(), 0, BufferVarInfo::kLetNode);
      } else if (allow_untyped_pointers_) {
        OnArrayDeclaration(let_var, let_var->dtype, 0, BufferVarInfo::kLetNode);
      } else {
        LOG(FATAL) << "Let statement of variable " << let_var->name_hint
                   << " is missing a type annotation, "
                   << "or type annotation is not a pointer to primitive";
      }
    }
  }

  /* Update the type map for a buffer based on its declaration
   *
   * @param buffer The Var representing the buffer.
   *
   * @param element_dtype The dtype of a single element of the buffer.
   * If unknown, when used with the allow_untyped_pointers option,
   * should be a handle dtype.
   *
   * @param extent The extent of the buffer. Zero if size is unknown.
   *
   * @param declaration_location How the buffer was allocated, so that
   * some locations can be rewritten without others.
   */
  void OnArrayDeclaration(Var buffer, DataType element_dtype, PrimExpr extent,
                          BufferVarInfo::DeclarationLocation declaration_location) {
    ICHECK(info_map_.find(buffer.get()) == info_map_.end())
        << "Array declaration of " << buffer->name_hint << " occurred multiple times.";

    // Boolean buffers are tracked as int8, matching how they are stored.
    if (element_dtype == DataType::Bool()) {
      element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes());
    }

    info_map_[buffer.get()] = {buffer, element_dtype, extent, declaration_location};
  }

  /* Update the type map for a buffer based on its usage
   *
   * @param value_dtype The dtype of the value being stored to or
   * loaded from the buffer.
   *
   * @param buffer The VarNode representing the buffer.
   *
   * @param indices The indices at which the value is being stored/loaded.
   */
  void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const Array<PrimExpr>& indices) {
    auto it = info_map_.find(buffer);
    ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer
                                  << ") occurred before its declaration.";
    BufferVarInfo& var_info = it->second;

    // Match the bool -> int8 mapping applied at declaration time.
    if (value_dtype.element_of() == DataType::Bool()) {
      value_dtype = DataType::Int(8).with_lanes(value_dtype.lanes());
    }

    if (var_info.element_dtype.is_handle()) {
      ICHECK(allow_untyped_pointers_) << "Variable " << buffer->name_hint
                                      << " was missing a type annotation in its declaration";
      // First access of an untyped pointer fixes its element type.
      var_info.element_dtype = value_dtype.element_of();
    }

    for (int i = 0; i < static_cast<int>(indices.size()) - 1; i++) {
      ICHECK(indices[i].dtype().is_scalar())
          << "Only the last index of a buffer access may be a vector type.";
    }
    // Lanes contributed by the (last) index; 1 for index-less accesses.
    int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1;

    DataType access_dtype = value_dtype;

    int lanes_used = var_info.element_dtype.lanes();

    // This can happen due to a previous pass that had rewrite_store_load =
    // false. This occurs from the StorageRewrite in tvm::lower, followed by the
    // PointerValueTypeRewrite in BuildSPIRV. The rewrite_store_load = false is
    // necessary because the C-based codegens do not yet support vectorized
    // pointer types (e.g. float16x4*). Once they do, this if statement should
    // instead be replaced by the below ICHECK_EQ.
    if (index_lanes * var_info.element_dtype.lanes() != value_dtype.lanes()) {
      ICHECK_EQ(index_lanes, value_dtype.lanes());
      lanes_used = 1;
      var_info.element_dtype = var_info.element_dtype.with_lanes(1);
    }

    // TODO(Lunderberg): Uncomment this check once it can be applied.
    // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615
    // for discussion.

    // ICHECK_EQ(index_lanes * var_info.element_dtype.lanes(), value_dtype.lanes())
    //     << "Attempting to retrieve " << value_dtype.lanes() << " lanes of data with "
    //     << index_lanes << " indices into an array whose elements have "
    //     << var_info.element_dtype.lanes() << " lanes. "
    //     << "Expected output with " << index_lanes * var_info.element_dtype.lanes()
    //     << " lanes.";

    // If the index is a RampNode with stride of 1 and offset divisible by
    // the number of lanes, then this array access could be vectorized.
    if (indices.size()) {
      const RampNode* ramp_index = indices[indices.size() - 1].as<RampNode>();
      if (ramp_index && is_one(ramp_index->stride)) {
        arith::ModularSet me = analyzer_.modular_set(ramp_index->base);
        if ((me->coeff % ramp_index->lanes == 0) && (me->base % ramp_index->lanes == 0)) {
          lanes_used = ramp_index->lanes;
        }
      }
    }

    var_info.access_dtype.insert(access_dtype.with_lanes(lanes_used));
  }

  // Map of buffer variable information determined
  std::unordered_map<const VarNode*, BufferVarInfo> info_map_;

  // If true, handle variables lacking a pointer type annotation adopt the
  // element type of their first access instead of raising an error.
  bool allow_untyped_pointers_{false};

  // internal analyzer
  arith::Analyzer analyzer_;
};
1332
1333/* \brief Rewrites buffer/pointer variables from scalar types to vectorized
1334 * types.
1335 *
1336 * Some runtimes do not allow casting between composite types and the underlying
1337 * base type (e.g. Vulkan, casting from 1-lane float16* to 4-lane float16x4*).
1338 * In these cases, in order to have vectorized load/store on an array, the
1339 * element type of that array must be vectorized. This is in contrast to C-style
1340 * runtimes, in which `float16x4* vec = *(float16x4*)(float_arr + offset)` is
1341 * valid.
1342 *
1343 * By default, VectorTypeRewriter will attempt to rewrite all buffer variables to
1344 * vectorized access, if the load/store occurring in the PrimFunc are all
1345 * vectorized. This includes adjusting the indices being used to access the
1346 * array. (e.g. If `float16* scalar_arr` is being converted to `float16x4*
1347 * vec_arr`, then `scalar_arr[Ramp(offset, 1, 4)]` will be converted to
1348 * `vec_arr[offset/4]`.)
1349 *
1350 * Currently, several of the C-style runtimes do not support buffers whose
1351 * elements are vectorized types, or rely on the presence of the Ramp nodes to
1352 * identify vectorized loads. The boolean parameters in the constructor are to
1353 * mimic the previous behavior of VectorTypeRewriter, to avoid breaking these
1354 * runtimes. Once all runtimes support vectorized buffer elements, these
1355 * parameters can be removed.
1356 */
class VectorTypeRewriter : public StmtExprMutator {
 public:
  /* Constructor
   *
   * @param info_map Per-buffer-variable usage information previously
   * collected by VectorTypeAccessChecker.
   *
   * @param rewrite_params Whether pointer-type parameters passed into the
   * function should be rewritten from scalar types to vectorized types.
   *
   * @param rewrite_buffer_map Whether buffers present in the buffer_map should
   * have their data variable be rewritten from scalar types to vectorized types.
   *
   * @param rewrite_allocate_node Whether the buffer variable associated with
   * AllocateNodes should be rewritten from scalar types to vectorized types.
   *
   * @param rewrite_indices Whether the indices to the Load and Store nodes
   * should be rewritten to correspond to the new buffer_var type.
   *
   * @param rewrite_let_node Whether pointer declarations in let nodes
   * should be re-written.
   *
   * @param rewrite_allocate_const_node Whether the buffer variable associated
   * with AllocateConstNodes should be rewritten to vectorized types.
   */
  VectorTypeRewriter(const std::unordered_map<const VarNode*, BufferVarInfo>& info_map,
                     bool rewrite_params = true, bool rewrite_buffer_map = true,
                     bool rewrite_allocate_node = true, bool rewrite_indices = true,
                     bool rewrite_let_node = true, bool rewrite_allocate_const_node = true)
      : rewrite_indices_(rewrite_indices) {
    // Build a bitmask of the declaration locations that may be rewritten.
    int rewrite_mask = 0;
    if (rewrite_params) {
      rewrite_mask |= BufferVarInfo::kPrimFuncParam;
    }
    if (rewrite_buffer_map) {
      rewrite_mask |= BufferVarInfo::kPrimFuncBufferMap;
    }
    if (rewrite_allocate_node) {
      rewrite_mask |= BufferVarInfo::kAllocateNode;
    }
    if (rewrite_let_node) {
      rewrite_mask |= BufferVarInfo::kLetNode;
    }
    if (rewrite_allocate_const_node) {
      rewrite_mask |= BufferVarInfo::kAllocateConstNode;
    }

    // Rewrite any buffer variables whose preferred type isn't their current type.
    for (const auto& pair : info_map) {
      const auto& var_info = pair.second;
      DataType preferred = var_info.get_preferred_dtype();
      if (preferred != var_info.element_dtype && (rewrite_mask & var_info.declaration_location)) {
        Var old_buffer_var = var_info.var;
        // New variable with the same name/storage scope but vectorized
        // pointee type.
        Var new_buffer_var(old_buffer_var->name_hint,
                           PointerType(PrimType(preferred), GetPtrStorageScope(old_buffer_var)),
                           old_buffer_var->span);

        rewrite_map_[var_info.var.get()] = {var_info.var, new_buffer_var, var_info.element_dtype,
                                            preferred};
      }
    }
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  /* Rewrite a BufferLoad/BufferStore whose buffer variable was selected for
   * vectorization: remap the buffer and rescale a trailing
   * Ramp(base, 1, lanes) index into units of the new element type.
   */
  template <typename Node>
  Node VisitBufferAccess(Node node) {
    if (!rewrite_indices_) {
      return node;
    }

    auto it = rewrite_map_.find(node->buffer->data.get());
    if (it == rewrite_map_.end()) {
      return node;
    }
    const auto& info = it->second;

    Array<PrimExpr> indices = node->indices;

    const RampNode* ramp_index = indices[indices.size() - 1].as<RampNode>();
    if (ramp_index && is_one(ramp_index->stride)) {
      // One new element covers ramp_index->lanes old elements.
      PrimExpr new_index =
          ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes);
      if (ramp_index->lanes != info.factor()) {
        // Access width differs from the rewrite factor; keep a ramp over
        // the remaining lanes.
        new_index = Ramp(new_index, ramp_index->stride, ramp_index->lanes / info.factor(),
                         ramp_index->span);
      }

      indices.Set(indices.size() - 1, new_index);
    }

    auto writer = node.CopyOnWrite();
    writer->buffer = RemapBuffer(node->buffer);
    writer->indices = indices;

    return node;
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto modified = VisitBufferAccess(node);

    // Not needed for BufferStoreNode, so we can't just call
    // LegalizeDtype() in VisitBufferAccess.
    if (node.same_as(modified)) {
      return std::move(node);
    } else {
      auto writer = modified.CopyOnWrite();
      writer->LegalizeDType();
      return std::move(modified);
    }
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    return VisitBufferAccess(std::move(node));
  }

  // Replace the bound variable of a LetStmt when it was rewritten.
  Stmt VisitStmt_(const LetStmtNode* op) final {
    auto it = rewrite_map_.find(op->var.get());
    PrimExpr value = this->VisitExpr(op->value);
    Stmt body = this->VisitStmt(op->body);
    Var var = (it == rewrite_map_.end()) ? op->var : it->second.new_buffer_var;
    if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    }
    return LetStmt(var, value, body);
  }

  /* Return the (possibly vectorized) version of `buf`, rewriting its data
   * variable, element dtype, and last shape dimension.  All results —
   * including identity mappings — are cached in buffer_map_.
   */
  Buffer RemapBuffer(Buffer buf) {
    auto cache_key = buf.get();

    auto cache_it = buffer_map_.find(cache_key);
    if (cache_it != buffer_map_.end()) {
      return cache_it->second;
    }

    auto info_it = rewrite_map_.find(buf->data.get());
    if (info_it != rewrite_map_.end()) {
      auto& info = info_it->second;

      // Shrink the last dimension by the vectorization factor.
      Array<PrimExpr> shape = buf->shape;
      PrimExpr last_dim = shape[shape.size() - 1];
      shape.Set(shape.size() - 1, last_dim / make_const(last_dim.dtype(), info.factor()));

      auto writer = buf.CopyOnWrite();
      writer->data = info.new_buffer_var;
      writer->dtype = info.new_element_dtype;
      writer->shape = shape;
    }

    buffer_map_[cache_key] = buf;
    return buf;
  }

  // Rewrite tvm_access_ptr() calls so that index and extent are expressed
  // in units of the new (vectorized) element type.
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::tvm_access_ptr())) {
      PrimExpr expr = StmtExprMutator::VisitExpr_(op);
      op = expr.as<CallNode>();

      if (!rewrite_indices_) {
        return expr;
      }

      const VarNode* buffer_var = op->args[1].as<VarNode>();
      auto it = rewrite_map_.find(buffer_var);
      if (it == rewrite_map_.end()) {
        return expr;
      }
      const auto& info = it->second;

      PrimExpr index = op->args[2];
      PrimExpr extent = op->args[3];
      PrimExpr flag = op->args[4];

      PrimExpr e_dtype = tir::TypeAnnotation(info.new_element_dtype);
      int factor = info.factor();
      extent = extent / make_const(extent.dtype(), factor);
      index = index / make_const(index.dtype(), factor);
      Array<PrimExpr> acc_args{e_dtype, info.new_buffer_var, index, extent, flag};
      return Call(info.new_element_dtype, builtin::tvm_access_ptr(), acc_args);

    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }

  // Rebuild an Allocate over a rewritten buffer variable with vectorized
  // dtype and the last extent divided by the vectorization factor.
  Stmt VisitStmt_(const AllocateNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<AllocateNode>();

    auto it = rewrite_map_.find(op->buffer_var.get());
    if (it == rewrite_map_.end()) {
      return stmt;
    }

    const auto& info = it->second;

    Var new_buffer_var = info.new_buffer_var;

    Array<PrimExpr> extents = op->extents;
    PrimExpr last_extent = extents[extents.size() - 1];
    extents.Set(extents.size() - 1, last_extent / make_const(last_extent.dtype(), info.factor()));
    // NOTE(review): the Allocate is rebuilt from only these fields; any other
    // fields of the original node (e.g. annotations/span, if present in this
    // IR version) are not forwarded — confirm this is intended.
    return Allocate(new_buffer_var, info.new_element_dtype, extents, op->condition, op->body);
  }

  // Same rewrite for AllocateConst nodes.
  Stmt VisitStmt_(const AllocateConstNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<AllocateConstNode>();

    auto it = rewrite_map_.find(op->buffer_var.get());
    if (it == rewrite_map_.end()) {
      return stmt;
    }

    const auto& info = it->second;

    Var new_buffer_var = info.new_buffer_var;

    int factor = info.new_element_dtype.lanes() / op->dtype.lanes();

    Array<PrimExpr> extents = op->extents;
    // NOTE(review): the divisor constant uses extents[0].dtype() while the
    // extent being divided is the last one; this only matters when extents
    // mix dtypes — confirm intended (compare VisitStmt_(AllocateNode)).
    extents.Set(extents.size() - 1,
                extents[extents.size() - 1] / make_const(extents[0].dtype(), factor));
    return AllocateConst(new_buffer_var, info.new_element_dtype, extents, op->data, op->body);
  }

  /* Update the parameters and all remaining variable references
   *
   * Should be called after calling operator() on the body of the
   * function.
   *
   * @param func_ptr A pointer to the PrimFunc being modified.
   */
  void Finalize(PrimFunc* func_ptr) {
    ICHECK(func_ptr) << "Finalize expects a non-null pointer";
    auto& func = *func_ptr;
    auto* n = func.CopyOnWrite();

    // Remap any remaining references to the old buffer variables
    Map<Var, PrimExpr> var_remap;
    for (const auto& pair : rewrite_map_) {
      const auto& info = pair.second;
      var_remap.Set(info.old_buffer_var, info.new_buffer_var);
    }
    n->body = Substitute(n->body, var_remap);

    // Remap the argument list to use the new buffer variables.
    Array<Var> new_params;
    for (const auto& old_param : n->params) {
      auto it = rewrite_map_.find(old_param.get());
      if (it == rewrite_map_.end()) {
        new_params.push_back(old_param);
      } else {
        const auto& info = it->second;
        new_params.push_back(info.new_buffer_var);
      }
    }
    n->params = new_params;

    // Remap the Buffer objects in PrimFunc::buffer_map so that the
    // buffers use the new buffer variables
    Map<Var, Buffer> new_buffer_map;
    for (const auto& pair : n->buffer_map) {
      Var key = pair.first;
      Buffer old_buffer = pair.second;
      Var old_var = old_buffer->data;
      Buffer new_buffer = RemapBuffer(old_buffer);
      new_buffer_map.Set(key, new_buffer);
    }
    n->buffer_map = new_buffer_map;
  }

 private:
  // Describes one buffer variable rewrite: old/new variable and dtypes.
  struct RewriteInfo {
    Var old_buffer_var;
    Var new_buffer_var;
    DataType old_element_dtype;
    DataType new_element_dtype;

    // How many old elements one new element spans (new lanes / old lanes).
    int factor() const {
      int old_lanes = old_element_dtype.lanes();
      int new_lanes = new_element_dtype.lanes();
      ICHECK_EQ(new_lanes % old_lanes, 0);
      return new_lanes / old_lanes;
    }
  };

  // Whether load/store indices should be rescaled to the new element type.
  bool rewrite_indices_{true};
  // Buffer variables selected for rewriting, keyed by old VarNode.
  std::unordered_map<const VarNode*, RewriteInfo> rewrite_map_;
  // Cache of remapped Buffer objects (including identity mappings).
  std::unordered_map<const BufferNode*, Buffer> buffer_map_;
};
1652
1653// Rewrite allocates, pointer parameters, and buffer map into vectorized versions
1654// if each access into a buffer is the same vector type.
1655PrimFunc PointerValueTypeRewrite(PrimFunc f, bool allow_untyped_pointers = false,
1656 bool rewrite_params = true, bool rewrite_buffer_map = true,
1657 bool rewrite_allocate_node = true, bool rewrite_indices = true,
1658 bool rewrite_let_node = true,
1659 bool rewrite_allocate_const_node = true) {
1660 VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers);
1661 checker(f->body);
1662
1663 VectorTypeRewriter rewriter(checker.info_map_, rewrite_params, rewrite_buffer_map,
1664 rewrite_allocate_node, rewrite_indices, rewrite_let_node,
1665 rewrite_allocate_const_node);
1666 PrimFuncNode* n = f.CopyOnWrite();
1667 n->body = rewriter(std::move(n->body));
1668 rewriter.Finalize(&f);
1669
1670 return f;
1671}
1672
1673namespace transform {
1674
1675Pass StorageRewrite() {
1676 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
1677 auto* n = f.CopyOnWrite();
1678 n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true);
1679 // Parameters may not be rewritten, but internal allocations may.
1680 // Vectorization of AllocateConst is currently disabled, as it has
1681 // indexing issues for types that include padding (e.g. int8x3
1682 // padded out to 32 bits) would require either rewriting
1683 // AllocateConst::data, or would require the code generators to
1684 // handle vectorized constants.
1685 return PointerValueTypeRewrite(std::move(f), true, false, false, true, true, true, false);
1686 };
1687 return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
1688}
1689
1690TVM_REGISTER_GLOBAL("tir.transform.StorageRewrite").set_body_typed(StorageRewrite);
1691
1692Pass PointerValueTypeRewrite() {
1693 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
1694 return PointerValueTypeRewrite(std::move(f));
1695 };
1696 return CreatePrimFuncPass(pass_func, 0, "tir.PointerValueTypeRewrite", {});
1697}
1698
1699TVM_REGISTER_GLOBAL("tir.transform.PointerValueTypeRewrite")
1700 .set_body_typed(PointerValueTypeRewrite);
1701
1702} // namespace transform
1703
1704} // namespace tir
1705} // namespace tvm
1706