1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file storage_rewrite.cc |
22 | * \brief Memory access pattern analysis and optimization. |
23 | * Re-write data access to enable memory sharing when possible. |
24 | */ |
25 | #include <tvm/arith/analyzer.h> |
26 | #include <tvm/ir/type.h> |
27 | #include <tvm/runtime/registry.h> |
28 | #include <tvm/target/target_info.h> |
29 | #include <tvm/tir/analysis.h> |
30 | #include <tvm/tir/builtin.h> |
31 | #include <tvm/tir/expr.h> |
32 | #include <tvm/tir/stmt_functor.h> |
33 | #include <tvm/tir/transform.h> |
34 | |
35 | #include <map> |
36 | #include <unordered_map> |
37 | #include <unordered_set> |
38 | |
39 | #include "../../runtime/thread_storage_scope.h" |
40 | #include "../ir/buffer_common.h" |
41 | #include "ir_utils.h" |
42 | |
43 | namespace tvm { |
44 | namespace tir { |
45 | |
46 | using runtime::StorageRank; |
47 | using runtime::StorageScope; |
48 | |
// Find a linear pattern of storage access
// Used for liveness analysis.
// Composite scopes (loop/thread_launch/IfThenElse) are represented by two points:
// before_scope -> scope_body -> after_scope
//
// The linear_seq_ stores before_scope and after_scope.
// Accesses to the arrays are stored at the after_scope point.
//
// Define "scope" as the body of For/thread_launch/IfThenElse.
// This pass tries to detect the last point at which we need to keep memory
// alive under the same scope as the allocate.
// The storage needs to be kept alive between the allocate and the last access.
// The free point is only inserted at the same scope as the allocate.
//
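// For example (a sketch, not exact TIR), with A allocated in the enclosing scope:
//
//   allocate(A)
//   for (i) { A[i] = B[i] + 1 }
//
// linear_seq_ roughly becomes
//
//   [0] for-begin  scope_pair_offset = +1
//   [1] for-end    scope_pair_offset = -1, touched = {A, ...}
//
// so the last access to A can be found by scanning linear_seq_ in reverse.
//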
63 | class LinearAccessPatternFinder final : public StmtExprVisitor { |
64 | public: |
  /*! \brief Record the touch history of a statement. */
66 | struct StmtEntry { |
    // The statement
68 | const Object* stmt; |
    // The offset within linear_seq_ to the matching end of the nested scope.
    // This is only set to non-zero if stmt is a nested scope.
    // If offset > 0, this is the begin entry and the end entry is at current_index + offset.
    // If offset < 0, this is the end entry and the begin entry is at current_index + offset.
73 | int64_t scope_pair_offset{0}; |
    // The buffer variables this statement touched.
75 | std::vector<const VarNode*> touched; |
76 | }; |
  // The allocation information of each buffer variable.
78 | struct AllocEntry { |
    // The number of physical dimensions of the allocation.
80 | size_t num_physical_dimensions{0}; |
81 | // scope level |
82 | size_t level{0}; |
83 | // allocation stmt |
84 | const AllocateNode* alloc{nullptr}; |
85 | }; |
86 | |
87 | void VisitStmt_(const AllocateNode* op) final { |
88 | size_t level = scope_.size(); |
89 | const VarNode* buf = op->buffer_var.get(); |
90 | |
91 | AllocEntry entry; |
92 | entry.alloc = op; |
93 | entry.level = level; |
    // Since StorageRewrite occurs after StorageFlatten/FlattenBuffer,
    // all allocations specify the extent of each physical dimension;
    // flat memory spaces have exactly one physical dimension.
97 | entry.num_physical_dimensions = op->extents.size(); |
98 | alloc_info_[buf] = entry; |
99 | |
100 | StmtExprVisitor::VisitStmt_(op); |
101 | } |
102 | |
103 | void VisitStmt_(const StoreNode* op) final { |
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
105 | } |
106 | |
107 | void VisitStmt_(const BufferStoreNode* op) final { |
108 | scope_.push_back(StmtEntry()); |
109 | // visit subexpr |
110 | StmtExprVisitor::VisitStmt_(op); |
111 | // Add write access. |
112 | const VarNode* buf = op->buffer->data.get(); |
113 | auto it = alloc_info_.find(buf); |
114 | if (it != alloc_info_.end() && it->second.alloc) { |
115 | ICHECK_LT(it->second.level, scope_.size()); |
116 | scope_[it->second.level].touched.push_back(buf); |
117 | |
118 | ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions) |
119 | << "Buffer " << op->buffer->name << " is allocated with " |
120 | << it->second.num_physical_dimensions |
121 | << " physical dimensions, but is accessed as having " |
122 | << op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl; |
123 | } |
124 | StmtEntry e = scope_.back(); |
125 | scope_.pop_back(); |
126 | if (e.touched.size() != 0) { |
127 | e.stmt = op; |
128 | linear_seq_.push_back(e); |
129 | } |
130 | } |
131 | |
132 | void VisitExpr_(const LoadNode* op) final { |
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
134 | } |
135 | |
136 | void VisitExpr_(const BufferLoadNode* op) final { |
    // Add read access.
138 | StmtExprVisitor::VisitExpr_(op); |
139 | const VarNode* buf = op->buffer->data.get(); |
140 | auto it = alloc_info_.find(buf); |
141 | if (it != alloc_info_.end() && it->second.alloc) { |
      ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
143 | scope_[it->second.level].touched.push_back(buf); |
144 | |
145 | ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions) |
146 | << "Buffer " << op->buffer->name << " is allocated with " |
147 | << it->second.num_physical_dimensions |
148 | << " physical dimensions, but is accessed as having " |
149 | << op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl; |
150 | } |
151 | } |
152 | |
153 | void VisitStmt_(const EvaluateNode* op) final { |
154 | scope_.push_back(StmtEntry()); |
155 | // visit subexpr |
156 | StmtExprVisitor::VisitStmt_(op); |
157 | StmtEntry e = scope_.back(); |
158 | scope_.pop_back(); |
159 | if (e.touched.size() != 0) { |
160 | e.stmt = op; |
161 | linear_seq_.push_back(e); |
162 | } |
163 | } |
164 | |
165 | void VisitExpr_(const CallNode* op) final { |
166 | if (op->op.same_as(builtin::address_of())) { |
167 | const BufferLoadNode* load = op->args[0].as<BufferLoadNode>(); |
168 | for (const auto& index : load->indices) { |
169 | this->VisitExpr(index); |
170 | } |
171 | } else { |
172 | StmtExprVisitor::VisitExpr_(op); |
173 | } |
174 | } |
175 | |
176 | void VisitExpr_(const VarNode* buf) final { |
    // A direct reference to the variable counts as a read.
178 | auto it = alloc_info_.find(buf); |
179 | if (it != alloc_info_.end() && it->second.alloc) { |
180 | ICHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint; |
181 | scope_[it->second.level].touched.push_back(buf); |
182 | } |
183 | } |
184 | |
185 | template <typename T> |
186 | void VisitNewScope(const T* op) { |
187 | scope_.push_back(StmtEntry()); |
188 | StmtEntry e; |
189 | e.stmt = op; |
190 | int64_t begin_index = static_cast<int64_t>(linear_seq_.size()); |
191 | // before scope. |
192 | linear_seq_.push_back(e); |
193 | StmtExprVisitor::VisitStmt_(op); |
194 | // after scope. |
195 | e.touched = std::move(scope_.back().touched); |
196 | scope_.pop_back(); |
197 | int64_t end_index = static_cast<int64_t>(linear_seq_.size()); |
198 | ICHECK_GT(end_index, begin_index); |
199 | e.scope_pair_offset = begin_index - end_index; |
200 | linear_seq_.push_back(e); |
201 | // record the pointer to end index. |
202 | ICHECK_NE(end_index, 0U); |
203 | linear_seq_[begin_index].scope_pair_offset = end_index - begin_index; |
204 | } |
205 | |
206 | void VisitStmt_(const AttrStmtNode* op) final { |
    // Only record the outermost thread extent.
208 | if (op->attr_key == attr::thread_extent && !in_thread_env_) { |
209 | in_thread_env_ = true; |
210 | VisitNewScope(op); |
211 | in_thread_env_ = false; |
212 | } else if (op->attr_key == attr::extern_scope) { |
213 | VisitNewScope(op); |
214 | } else if (op->attr_key == attr::virtual_thread) { |
215 | VisitNewScope(op); |
216 | } else { |
217 | StmtExprVisitor::VisitStmt_(op); |
218 | } |
219 | } |
220 | |
221 | void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); } |
222 | |
223 | void VisitStmt_(const ForNode* op) final { VisitNewScope(op); } |
224 | |
225 | void VisitStmt_(const WhileNode* op) final { VisitNewScope(op); } |
226 | |
227 | void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); } |
228 | |
229 | void VisitStmt_(const LetStmtNode* op) final { VisitNewScope(op); } |
230 | |
231 | // linearized access sequence. |
232 | std::vector<StmtEntry> linear_seq_; |
233 | // The storage scope of each buffer |
234 | std::unordered_map<const VarNode*, AllocEntry> alloc_info_; |
235 | |
236 | private: |
237 | // Whether already in thread env. |
238 | bool in_thread_env_{false}; |
239 | // The scope stack. |
240 | std::vector<StmtEntry> scope_; |
241 | }; |
242 | |
// Verify if the statement can be run safely in an inplace fashion
244 | // |
245 | // Detect pattern: dst[index] = f(src[index]) |
246 | // |
247 | // WARNING: the current detection algorithm cannot handle the case |
248 | // when a location in an array is written multiple times |
249 | // |
// For example, the following program will pass the check,
// but we cannot make A and B the same array.
252 | // |
253 | // A[0] = B[0] + 1 |
254 | // A[0] = B[0] + 1 |
255 | // |
// The high-level code generator needs to ensure that the generated
// code only writes each location of the target array once.
258 | // |
259 | // This is the case with IR generated by the current compute schedule. |
// We explicitly return false if we find there is an extern block
// which can contain arbitrary IR.
//
// Nevertheless, the inplace detector should be used with care.
// We may also consider introducing a condition checker that verifies
// that every index is visited only once, as an absolutely sufficient condition.
266 | // |
267 | // The code after inplace transformation is no longer idempotent. |
268 | // |
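// For example (a sketch of what the verifier accepts and rejects):
//
//   accepted:  dst[i] = src[i] + 1       (same index, src read at the top level)
//   rejected:  dst[i] = dst[i] + src[i]  (reads from dst)
//   rejected:  dst[i] = src[i + 1]       (load index differs from the store index)
//   rejected:  dst[i] = src[src[i]]      (indirect/nested load)
//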
269 | class InplaceOpVerifier : public StmtExprVisitor { |
270 | public: |
271 | bool Check(const Object* stmt, const VarNode* dst, const VarNode* src) { |
272 | dst_ = dst; |
273 | src_ = src; |
274 | result_ = true; |
275 | if (stmt->IsInstance<AttrStmtNode>()) { |
276 | VisitStmt_(static_cast<const AttrStmtNode*>(stmt)); |
277 | } else if (stmt->IsInstance<ForNode>()) { |
278 | VisitStmt_(static_cast<const ForNode*>(stmt)); |
279 | } else if (stmt->IsInstance<IfThenElseNode>()) { |
280 | VisitStmt_(static_cast<const IfThenElseNode*>(stmt)); |
281 | } else if (stmt->IsInstance<WhileNode>()) { |
282 | VisitStmt_(static_cast<const WhileNode*>(stmt)); |
283 | } else if (stmt->IsInstance<StoreNode>()) { |
284 | VisitStmt_(static_cast<const StoreNode*>(stmt)); |
285 | } else if (stmt->IsInstance<BufferStoreNode>()) { |
286 | VisitStmt_(static_cast<const BufferStoreNode*>(stmt)); |
287 | } else { |
288 | return false; |
289 | } |
290 | return result_; |
291 | } |
292 | |
293 | using StmtExprVisitor::VisitStmt_; |
294 | |
295 | void VisitStmt(const Stmt& n) final { |
296 | if (!result_) return; |
297 | StmtExprVisitor::VisitStmt(n); |
298 | } |
299 | void VisitExpr(const PrimExpr& n) final { |
300 | if (!result_) return; |
301 | StmtExprVisitor::VisitExpr(n); |
302 | } |
303 | |
304 | void VisitExpr_(const VarNode* op) final { |
    // assume all opaque accesses are unsafe
306 | if (op == dst_ || op == src_) { |
307 | result_ = false; |
308 | return; |
309 | } |
310 | } |
311 | |
312 | void VisitStmt_(const StoreNode* op) final { |
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
314 | } |
315 | |
316 | void VisitStmt_(const BufferStoreNode* op) final { |
317 | ++mem_nest_; |
318 | for (const auto& index : op->indices) { |
319 | this->VisitExpr(index); |
320 | } |
321 | --mem_nest_; |
322 | if (op->buffer->data.get() == dst_) { |
323 | store_ = op; |
324 | this->VisitExpr(op->value); |
325 | store_ = nullptr; |
326 | } else { |
327 | this->VisitExpr(op->value); |
328 | } |
329 | } |
330 | |
331 | void VisitStmt_(const AttrStmtNode* op) final { |
332 | // always reject extern code |
333 | if (op->attr_key == attr::extern_scope || op->attr_key == attr::volatile_scope) { |
334 | result_ = false; |
335 | return; |
336 | } |
337 | StmtExprVisitor::VisitStmt_(op); |
338 | } |
339 | |
340 | void VisitExpr_(const LoadNode* op) final { |
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
342 | } |
343 | |
344 | void VisitExpr_(const BufferLoadNode* op) final { |
345 | const VarNode* buf = op->buffer->data.get(); |
346 | // cannot read from dst_ (no reduction) |
347 | if (buf == dst_) { |
348 | result_ = false; |
349 | return; |
350 | } |
351 | // do not allow indirect memory load |
352 | if (mem_nest_ != 0) { |
353 | result_ = false; |
354 | return; |
355 | } |
356 | if (src_ == buf) { |
357 | if (store_ == nullptr || store_->value.dtype() != op->dtype) { |
358 | result_ = false; |
359 | return; |
360 | } |
361 | ICHECK_EQ(store_->indices.size(), op->indices.size()) |
          << "Store and load to the same buffer " << buf->name_hint
          << " have differing numbers of indices";
364 | for (size_t i = 0; i < store_->indices.size(); i++) { |
365 | if (!tir::ExprDeepEqual()(store_->indices[i], op->indices[i])) { |
366 | result_ = false; |
367 | return; |
368 | } |
369 | } |
370 | } |
371 | ++mem_nest_; |
372 | StmtExprVisitor::VisitExpr_(op); |
373 | --mem_nest_; |
374 | } |
375 | |
376 | private: |
377 | // result of the check |
378 | bool result_{true}; |
379 | // destination memory |
380 | const VarNode* dst_; |
381 | // source variable |
382 | const VarNode* src_; |
  // Counter of nested loads.
  // It is not safe to inplace when there is a nested load like A[B[i]].
385 | int mem_nest_{0}; |
386 | // The current store to be inspected |
387 | const BufferStoreNode* store_{nullptr}; |
388 | }; |
389 | |
390 | /* \brief Rewrite and merge memory allocation. |
391 | * |
392 | * Using LinearAccessPatternFinder, determines which buffers could share an |
393 | * allocation. This includes both sequential usage of the same buffer and |
394 | * merging small allocations at the same scope into a single larger allocation. |
395 | * The merging of small allocations requires the codegen to cast the resulting |
396 | * value from the storage type to the output type after access. |
397 | */ |
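//
// For example (a sketch): if buffer A is no longer live by the time buffer B
// of the same scope and a compatible type is allocated,
//
//   allocate(A); ... last use of A ...; allocate(B); ... uses of B ...
//
// then B's loads/stores can be remapped onto A's backing variable, leaving a
// single allocation after rewriting.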
398 | class StoragePlanRewriter : public StmtExprMutator { |
399 | public: |
400 | using StmtEntry = LinearAccessPatternFinder::StmtEntry; |
401 | using AllocEntry = LinearAccessPatternFinder::AllocEntry; |
402 | |
403 | Stmt Rewrite(Stmt stmt, bool detect_inplace) { |
404 | detect_inplace_ = detect_inplace; |
405 | // plan the rewrite |
406 | LinearAccessPatternFinder finder; |
407 | finder(stmt); |
408 | this->LivenessAnalysis(finder.linear_seq_); |
409 | this->PlanMemory(finder.linear_seq_, finder.alloc_info_); |
410 | this->PrepareNewAlloc(); |
411 | // start rewrite |
412 | stmt = operator()(std::move(stmt)); |
413 | if (attach_map_.count(nullptr)) { |
414 | return MakeAttach(attach_map_.at(nullptr), stmt); |
415 | } |
416 | return stmt; |
417 | } |
418 | |
419 | Stmt VisitStmt_(const StoreNode* op) final { |
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
421 | } |
422 | |
423 | PrimExpr VisitExpr_(const LoadNode* op) final { |
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
425 | } |
426 | |
427 | template <typename Node> |
428 | Node VisitBufferAccess(Node node) { |
429 | auto it = alloc_map_.find(node->buffer->data.get()); |
430 | if (it != alloc_map_.end()) { |
431 | Buffer buf = RemapBuffer(node->buffer, it->second->alloc_var); |
432 | |
433 | Array<PrimExpr> indices = node->indices; |
434 | indices.Set(indices.size() - 1, |
435 | RemapIndex(node->buffer->dtype, indices[indices.size() - 1], it->second)); |
436 | |
437 | auto writer = node.CopyOnWrite(); |
438 | writer->buffer = buf; |
439 | writer->indices = indices; |
440 | } |
441 | return node; |
442 | } |
443 | |
444 | Buffer RemapBuffer(Buffer buf, Var new_backing_array) { |
445 | auto key = buf.get(); |
446 | auto it = buffer_remap_.find(key); |
447 | if (it != buffer_remap_.end()) { |
448 | ICHECK_EQ(it->second->data.get(), new_backing_array.get()) |
449 | << "Cannot remap buffer " << buf->name << " to use backing array " |
450 | << new_backing_array->name_hint << ", previously used backing array " |
451 | << it->second->data->name_hint; |
452 | return it->second; |
453 | } |
454 | |
455 | Buffer remapped = Buffer(new_backing_array, buf->dtype, buf->shape, buf->strides, |
456 | buf->elem_offset, new_backing_array->name_hint, buf->data_alignment, |
457 | buf->offset_factor, buf->buffer_type, buf->axis_separators, buf->span); |
458 | buffer_remap_[key] = remapped; |
459 | return remapped; |
460 | } |
461 | |
462 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
463 | auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
464 | return VisitBufferAccess(std::move(node)); |
465 | } |
466 | |
467 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
468 | auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
469 | return VisitBufferAccess(std::move(node)); |
470 | } |
471 | |
472 | PrimExpr VisitExpr_(const VarNode* op) final { |
473 | auto it = alloc_map_.find(op); |
474 | if (it != alloc_map_.end()) { |
475 | if (it->second->bits_offset != 0) { |
        LOG(WARNING) << "Taking the address of a merged buffer variable; this could cause errors";
477 | } |
478 | return it->second->alloc_var; |
479 | } else { |
480 | return GetRef<PrimExpr>(op); |
481 | } |
482 | } |
483 | PrimExpr VisitExpr_(const CallNode* op) final { |
484 | if (op->op.same_as(builtin::tvm_access_ptr())) { |
485 | ICHECK_EQ(op->args.size(), 5U); |
486 | DataType dtype = op->args[0].dtype(); |
487 | const VarNode* buffer = op->args[1].as<VarNode>(); |
488 | auto it = alloc_map_.find(buffer); |
489 | if (it == alloc_map_.end()) { |
490 | return StmtExprMutator::VisitExpr_(op); |
491 | } |
492 | const StorageEntry* se = it->second; |
493 | PrimExpr offset = this->VisitExpr(op->args[2]); |
494 | PrimExpr extent = this->VisitExpr(op->args[3]); |
495 | uint64_t elem_bits = dtype.bits() * dtype.lanes(); |
496 | ICHECK_EQ(se->bits_offset % elem_bits, 0U); |
497 | if (se->bits_offset != 0) { |
498 | offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset; |
499 | } |
500 | return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset, extent, op->args[4]}); |
501 | } else { |
502 | return StmtExprMutator::VisitExpr_(op); |
503 | } |
504 | } |
505 | |
506 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
507 | if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || |
508 | attr::IsPragmaKey(op->attr_key)) { |
509 | // remake all the allocation at the attach scope. |
510 | if (attach_map_.count(op)) { |
511 | auto& svec = attach_map_[op]; |
512 | Stmt stmt = StmtExprMutator::VisitStmt_(op); |
513 | op = stmt.as<AttrStmtNode>(); |
514 | return AttrStmt(op->node, op->attr_key, op->value, MakeAttach(svec, op->body)); |
515 | } else { |
516 | return StmtExprMutator::VisitStmt_(op); |
517 | } |
518 | } else if (op->attr_key == attr::volatile_scope) { |
519 | Stmt stmt = StmtExprMutator::VisitStmt_(op); |
520 | op = stmt.as<AttrStmtNode>(); |
521 | auto it = alloc_map_.find(op->node.as<VarNode>()); |
522 | if (it == alloc_map_.end()) return stmt; |
523 | return AttrStmt(it->second->alloc_var, op->attr_key, op->value, op->body); |
524 | } else { |
525 | return StmtExprMutator::VisitStmt_(op); |
526 | } |
527 | } |
528 | |
529 | Stmt VisitStmt_(const ForNode* op) final { |
    ICHECK(op->kind != ForKind::kVectorized) << "VectorizeLoop before LiftStorageAlloc";
531 | // remake all the allocation at the attach scope. |
532 | if (attach_map_.count(op)) { |
533 | auto& svec = attach_map_[op]; |
534 | Stmt stmt = StmtExprMutator::VisitStmt_(op); |
535 | op = stmt.as<ForNode>(); |
536 | return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, op->body), |
537 | op->thread_binding, op->annotations); |
538 | } else { |
539 | return StmtExprMutator::VisitStmt_(op); |
540 | } |
541 | } |
542 | |
543 | Stmt VisitStmt_(const AllocateNode* op) final { return this->VisitStmt(op->body); } |
544 | |
545 | private: |
546 | struct StorageEntry { |
    // The scope that this alloc attaches after.
    // For shared/local memory it is the beginning of the thread extent.
    // For global memory it is nullptr, meaning the beginning of everything.
550 | const Object* attach_scope_{nullptr}; |
551 | // The constant size of the buffer in bits, only used if it is constant |
552 | uint64_t const_nbits{0}; |
553 | // The storage scope. |
554 | StorageScope scope; |
    // The physical dimensionality of the allocations. Since
    // StorageRewrite is applied after StorageFlatten/FlattenBuffer,
    // this is the size of `AllocateNode::extents`. If moved to before
    // StorageFlatten/FlattenBuffer, this should instead be the size of
    // `BufferNode::shape`.
558 | size_t ndim; |
    // Allocs that share this entry.
560 | std::vector<const AllocateNode*> allocs; |
561 | // The children of this entry, not including itself. |
562 | std::vector<StorageEntry*> merged_children; |
563 | // The replacement allocation, if any. |
564 | Stmt new_alloc; |
565 | // The var expr of new allocation. |
566 | Var alloc_var; |
567 | // The allocation element type. |
568 | DataType elem_type; |
    // This is non-zero if this allocate is folded into another one:
    // the address (in bits) becomes alloc_var + bits_offset, which
    // can be effectively converted to the element type.
    // We need to convert bits_offset to an offset of the specific element type later.
    //
    // We use bits (instead of bytes) to support non-conventional indexing in hardware.
    // When we are merging buffers together, bits_offset is set to be aligned
    // to a certain value given by the max_simd_bits property of the special memory.
    //
    // This allows effective sharing among different types as long as their alignment
    // requirement fits into max_simd_bits.
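    //
    // For example (a sketch): a float32 buffer folded in at bits_offset = 1024
    // has its flattened index remapped to index + 1024 / 32 = index + 32
    // (see RemapIndex below).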
580 | uint64_t bits_offset{0}; |
581 | }; |
582 | |
  // Checks whether the storage_scope is specially tagged for a specific memory.
584 | // Special memory is all combined into a single allocation. |
585 | bool IsSpecialTaggedMemory(const StorageScope& scope) { |
586 | return scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".workspace" && |
587 | scope.tag != ".vtcm" ; |
588 | } |
589 | |
  // Event entry in liveness analysis, recording the variables
  // generated (gen) and killed at each statement.
592 | struct EventEntry { |
593 | // variables we generate |
594 | std::vector<const VarNode*> gen; |
595 | // variables we kill |
596 | std::vector<const VarNode*> kill; |
597 | }; |
598 | |
599 | Stmt MakeAttach(const std::vector<StorageEntry*>& svec, Stmt body) { |
600 | std::vector<Stmt> nest; |
601 | for (StorageEntry* e : svec) { |
602 | if (e->new_alloc.defined()) { |
603 | nest.push_back(e->new_alloc); |
604 | } |
605 | } |
606 | return MergeNest(nest, body); |
607 | } |
608 | // Remap the index |
609 | PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry* e) { |
610 | if (e->bits_offset == 0) return index; |
611 | uint64_t elem_bits = dtype.bits(); |
612 | ICHECK_EQ(e->bits_offset % elem_bits, 0U); |
613 | return make_const(index.dtype(), e->bits_offset / elem_bits) + index; |
614 | } |
615 | // Prepare the new allocations |
616 | void PrepareNewAlloc() { |
617 | for (size_t i = 0; i < alloc_vec_.size(); ++i) { |
618 | StorageEntry* e = alloc_vec_[i].get(); |
619 | attach_map_[e->attach_scope_].push_back(e); |
620 | } |
621 | // find allocation via attach map. |
622 | for (auto& kv : attach_map_) { |
      // find the element with the largest size in bytes.
624 | std::vector<StorageEntry*>& vec = kv.second; |
625 | // try to find merge, for tagged memory |
626 | for (size_t i = 0; i < vec.size(); ++i) { |
627 | StorageEntry* e = vec[i]; |
628 | if (IsSpecialTaggedMemory(e->scope)) { |
          ICHECK_NE(e->const_nbits, 0U) << "Specially tagged memory must have a constant size";
630 | for (size_t j = 0; j < i; ++j) { |
631 | if (e->scope == vec[j]->scope) { |
632 | vec[j]->merged_children.push_back(e); |
633 | break; |
634 | } |
635 | } |
636 | } |
637 | } |
638 | // Start allocation |
639 | for (size_t i = 0; i < vec.size(); ++i) { |
640 | StorageEntry* e = vec[i]; |
641 | // already merged |
642 | if (e->bits_offset != 0) continue; |
643 | if (e->merged_children.size() != 0) { |
644 | NewAllocTagMerged(e); |
645 | continue; |
646 | } |
        // Get the allocation size.
648 | e->alloc_var = e->allocs[0]->buffer_var; |
649 | DataType alloc_type = e->allocs[0]->dtype; |
650 | for (const AllocateNode* op : e->allocs) { |
651 | if (op->dtype.lanes() > alloc_type.lanes()) { |
652 | alloc_type = op->dtype; |
653 | } |
654 | } |
655 | |
656 | bool all_allocs_identical = std::all_of( |
657 | e->allocs.begin() + 1, e->allocs.end(), [&](const AllocateNode* op) -> bool { |
658 | const AllocateNode* first = *e->allocs.begin(); |
659 | if (op->dtype != first->dtype) { |
660 | return false; |
661 | } |
662 | if (op->extents.size() != first->extents.size()) { |
663 | return false; |
664 | } |
665 | ExprDeepEqual expr_equal; |
666 | for (size_t i = 0; i < op->extents.size(); i++) { |
667 | if (!expr_equal(op->extents[i], first->extents[i])) { |
668 | return false; |
669 | } |
670 | } |
671 | return true; |
672 | }); |
673 | |
674 | if (all_allocs_identical) { |
675 | // simply use the original allocation. |
676 | e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents, |
677 | e->allocs[0]->condition, Evaluate(0)); |
678 | if (IsSpecialTaggedMemory(e->scope)) { |
679 | MemoryInfo info = GetMemoryInfo(e->scope.to_string()); |
680 | if (info.defined()) { |
681 | uint64_t total_elem = e->const_nbits / e->elem_type.bits(); |
682 | ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits) |
                << "Allocation exceeds the bound of memory tag " << e->scope.to_string();
684 | } |
685 | } |
686 | } else { |
687 | // Build a merged allocation |
688 | PrimExpr combo_size; |
689 | for (const AllocateNode* op : e->allocs) { |
690 | ICHECK_EQ(op->extents.size(), 1) |
691 | << "Buffer var " << op->buffer_var->name_hint |
692 | << " was identified as a re-usable allocation, but has " << op->extents.size() |
693 | << " physical dimensions. " |
694 | << "Currently, only flat 1-d memory spaces should be identified as re-usable " |
                   "allocations.";
696 | PrimExpr sz = op->extents[0]; |
697 | auto nbits = op->dtype.bits() * op->dtype.lanes(); |
698 | if (const auto* imm = sz.as<IntImmNode>()) { |
699 | if (imm->value > std::numeric_limits<int>::max() / nbits) { |
700 | LOG(WARNING) << "The allocation requires : " << imm->value << " * " << nbits |
701 | << " bits, which is greater than the maximum of" |
702 | " int32. The size is cast to int64." |
703 | << "\n" ; |
704 | sz = make_const(DataType::Int(64), imm->value); |
705 | } |
706 | } |
707 | // transform to bits |
708 | auto sz_nbits = sz * nbits; |
709 | if (combo_size.defined()) { |
710 | combo_size = max(combo_size, sz_nbits); |
711 | } else { |
712 | combo_size = sz_nbits; |
713 | } |
714 | } |
715 | // transform to alloc bytes |
716 | auto type_bits = alloc_type.bits() * alloc_type.lanes(); |
717 | bool divided = analyzer_.CanProve(indexmod(combo_size, type_bits) == 0); |
718 | combo_size = indexdiv(combo_size, type_bits); |
          // round up when the size cannot be evenly divided
720 | if (!divided) { |
721 | combo_size = combo_size + make_const(DataType::Int(32), 1); |
722 | } |
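          // For example (a sketch): merging float32[30] (960 bits) with
          // int8[100] (800 bits), assuming the merged element type is float32,
          // gives combo_size = max(960, 800) / 32 = 30 elements; if the maximum
          // were not divisible by 32, one extra element would be added.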
723 | combo_size = analyzer_.Simplify(combo_size); |
724 | e->new_alloc = |
725 | Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0)); |
726 | if (IsSpecialTaggedMemory(e->scope)) { |
727 | MemoryInfo info = GetMemoryInfo(e->scope.to_string()); |
728 | if (info.defined()) { |
729 | uint64_t total_elem = e->const_nbits / e->elem_type.bits(); |
730 | ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits) |
                  << "Allocation exceeds the bound of memory tag " << e->scope.to_string();
732 | } |
733 | } |
734 | } |
735 | } |
736 | } |
737 | } |
738 | // New allocation for merged data |
739 | void NewAllocTagMerged(StorageEntry* e) { |
740 | ICHECK_NE(e->scope.tag.length(), 0U); |
741 | // allocate with element type. |
742 | ICHECK_NE(e->const_nbits, 0U); |
743 | MemoryInfo info = GetMemoryInfo(e->scope.to_string()); |
744 | uint64_t total_bits = e->const_nbits; |
745 | // By default, align to 32 bits. |
746 | size_t align = 32; |
747 | if (info.defined()) { |
748 | align = info->max_simd_bits; |
749 | } |
750 | // Always align to max_simd_bits |
751 | // so we can remap types by keeping this property |
752 | if (total_bits % align != 0) { |
753 | total_bits += align - (total_bits % align); |
754 | } |
755 | e->alloc_var = e->allocs[0]->buffer_var; |
756 | for (StorageEntry* child : e->merged_children) { |
757 | ICHECK_NE(child->const_nbits, 0U); |
758 | ICHECK_NE(total_bits, 0U); |
759 | child->bits_offset = total_bits; |
760 | child->alloc_var = e->alloc_var; |
761 | total_bits += child->const_nbits; |
762 | if (total_bits % align != 0) { |
763 | total_bits += align - (total_bits % align); |
764 | } |
765 | } |
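    // For example (a sketch), with align = 128 bits: a parent entry of 300 bits
    // is padded to 384, a child of 200 bits is then placed at bits_offset = 384,
    // and the running total of 584 bits is padded to 640 before computing the
    // final element count below.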
766 | uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes(); |
767 | PrimExpr alloc_size = |
768 | make_const(e->allocs[0]->extents[0].dtype(), (total_bits + type_bits - 1) / type_bits); |
769 | e->new_alloc = Allocate(e->alloc_var, e->elem_type, {alloc_size}, const_true(), Evaluate(0)); |
770 | if (info.defined()) { |
771 | ICHECK_LE(total_bits, info->max_num_bits) |
          << "Allocation exceeds the bound of memory tag " << e->scope.to_string();
773 | } |
774 | } |
  // Liveness analysis to find the gen and kill points of each variable.
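  // For example (a sketch): given the sequence [s1 touches A], [s2 touches A, B],
  // the forward scan records the gen of A at s1 and of B at s2 (first touch),
  // while the reverse scan records the kill of both A and B at s2 (last touch).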
776 | void LivenessAnalysis(const std::vector<StmtEntry>& seq) { |
777 | // find kill point, do a reverse linear scan. |
778 | std::unordered_set<const VarNode*> touched; |
779 | for (size_t i = seq.size(); i != 0; --i) { |
780 | const StmtEntry& s = seq[i - 1]; |
781 | for (const VarNode* buffer : s.touched) { |
782 | if (!touched.count(buffer)) { |
783 | touched.insert(buffer); |
784 | event_map_[s.stmt].kill.push_back(buffer); |
785 | } |
786 | } |
787 | } |
788 | // find gen point, do forward scan |
789 | touched.clear(); |
790 | for (size_t i = 0; i < seq.size(); ++i) { |
791 | int64_t offset = seq[i].scope_pair_offset; |
792 | if (offset < 0) continue; |
793 | const StmtEntry& s = seq[i + offset]; |
794 | for (const VarNode* buffer : s.touched) { |
795 | if (!touched.count(buffer)) { |
796 | touched.insert(buffer); |
797 | event_map_[s.stmt].gen.push_back(buffer); |
798 | } |
799 | } |
800 | } |
801 | } |
802 | void PlanNewScope(const Object* op) { |
803 | if (thread_scope_ != nullptr) { |
804 | ICHECK(thread_scope_ == op); |
      // erase all memory attached to this scope.
806 | for (auto it = const_free_map_.begin(); it != const_free_map_.end();) { |
807 | if (it->second->attach_scope_ == op) { |
808 | it = const_free_map_.erase(it); |
809 | } else { |
810 | ++it; |
811 | } |
812 | } |
813 | for (auto it = sym_free_list_.begin(); it != sym_free_list_.end();) { |
814 | if ((*it)->attach_scope_ == op) { |
815 | it = sym_free_list_.erase(it); |
816 | } else { |
817 | ++it; |
818 | } |
819 | } |
820 | thread_scope_ = nullptr; |
821 | } else { |
822 | thread_scope_ = op; |
823 | } |
824 | } |
825 | |
826 | // Memory plan algorithm |
827 | void PlanMemory(const std::vector<StmtEntry>& seq, |
828 | const std::unordered_map<const VarNode*, AllocEntry>& alloc_info) { |
829 | std::unordered_set<const VarNode*> inplace_flag; |
830 | |
831 | for (size_t i = 0; i < seq.size(); ++i) { |
832 | const StmtEntry& s = seq[i]; |
833 | auto it = event_map_.find(seq[i].stmt); |
834 | |
      // scope_pair_offset >= 0 means it is either
      // - a leaf stmt (offset = 0)
      // - the beginning of a scope (offset > 0)
      // In both cases, we need to handle the gen event correctly
839 | if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) { |
840 | // Inplace operation detection |
841 | // specially handle this |
842 | bool detect_inplace = detect_inplace_ && (it->second.gen.size() <= 2); |
843 | |
844 | for (const VarNode* var : it->second.gen) { |
845 | ICHECK(alloc_info.count(var)); |
846 | const AllocEntry& entry = alloc_info.at(var); |
847 | const AllocateNode* alloc = entry.alloc; |
848 | auto storage_scope = StorageScope::Create(GetPtrStorageScope(GetRef<Var>(var))); |
849 | StorageEntry* dst_entry = nullptr; |
850 | // inplace detection |
851 | if (detect_inplace) { |
852 | // only one inplace var for s.stmt |
853 | bool inplace_found = false; |
854 | for (const VarNode* src : it->second.kill) { |
855 | if (!inplace_flag.count(src) && alloc_map_.count(src)) { |
856 | InplaceOpVerifier visitor; |
857 | StorageEntry* src_entry = alloc_map_.at(src); |
858 | if (src_entry->scope == storage_scope && |
859 | src_entry->attach_scope_ == thread_scope_ && |
860 | src_entry->elem_type == alloc->dtype.element_of() && |
861 | visitor.Check(s.stmt, var, src)) { |
862 | uint64_t const_nbits = static_cast<uint64_t>(alloc->ConstantAllocationSize()) * |
863 | alloc->dtype.bits() * alloc->dtype.lanes(); |
864 | if (src_entry->const_nbits == const_nbits && !inplace_found) { |
865 | // successfully inplace |
866 | dst_entry = src_entry; |
867 | inplace_flag.insert(src); |
868 | inplace_found = true; |
869 | } |
870 | } |
871 | } |
872 | } |
873 | } |
874 | if (dst_entry == nullptr) { |
875 | dst_entry = |
876 | FindAlloc(alloc, thread_scope_, storage_scope, entry.num_physical_dimensions); |
877 | } |
878 | dst_entry->allocs.emplace_back(alloc); |
879 | alloc_map_[var] = dst_entry; |
880 | } |
881 | } |
882 | // enter/exit new scope |
883 | if (s.stmt->IsInstance<AttrStmtNode>()) { |
884 | const auto* op = static_cast<const AttrStmtNode*>(s.stmt); |
885 | if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || |
886 | attr::IsPragmaKey(op->attr_key)) { |
887 | PlanNewScope(op); |
888 | } else { |
889 | ICHECK(op->attr_key == attr::extern_scope); |
890 | } |
891 | } else if (s.stmt->IsInstance<ForNode>()) { |
892 | const auto* op = static_cast<const ForNode*>(s.stmt); |
893 | if (op->kind == ForKind::kParallel) { |
894 | if (thread_scope_ == nullptr || thread_scope_ == op) { |
895 | PlanNewScope(op); |
896 | } |
897 | } |
898 | } |
      // scope_pair_offset <= 0 means it is either
      // - a leaf stmt (offset = 0)
      // - the end of a scope (offset < 0)
902 | // In both cases, we need to handle the kill event correctly |
903 | if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) { |
904 | for (const VarNode* var : it->second.kill) { |
          // skip buffers that have already been replaced by inplace
906 | if (!inplace_flag.count(var)) { |
907 | this->Free(var); |
908 | } |
909 | } |
910 | } |
911 | } |
912 | } |
913 | // Allocate new storage entry. |
914 | StorageEntry* NewAlloc(const AllocateNode* op, const Object* attach_scope, |
915 | const StorageScope& scope, size_t const_nbits) { |
916 | ICHECK(op != nullptr); |
    // Reuse was not successful; allocate a new buffer.
918 | auto entry = std::make_unique<StorageEntry>(); |
919 | entry->attach_scope_ = attach_scope; |
920 | entry->scope = scope; |
921 | entry->elem_type = op->dtype.element_of(); |
922 | entry->const_nbits = const_nbits; |
923 | StorageEntry* e = entry.get(); |
924 | alloc_vec_.emplace_back(std::move(entry)); |
925 | return e; |
926 | } |
927 | |
928 | StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope, |
929 | const StorageScope& scope, size_t num_physical_dimensions) { |
930 | ICHECK(op != nullptr); |
931 | // skip plan for local variable, |
932 | // compiler can do a better job with register allocation. |
933 | const uint64_t match_range = 16; |
934 | uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes(); |
935 | uint64_t const_nbits = static_cast<uint64_t>(op->ConstantAllocationSize() * op_elem_bits); |
936 | |
937 | // If the size of the array isn't known at compile-time, it must |
938 | // have its own allocation with size determined at runtime. |
939 | bool is_known_size = (const_nbits != 0); |
940 | |
941 | // Currently, only flat memory spaces can be re-used. Packing |
942 | // into N-d space (e.g. 2-d texture memory on GPUs) will require |
943 | // more in-depth algorithms. |
944 | bool is_flat_memory_space = (num_physical_dimensions == 1); |
945 | |
    // disable reuse of small arrays, as they will be lowered to registers in LLVM
    // This rule only applies if we are using non-special memory
948 | bool is_small_array = |
949 | (scope.tag.length() == 0) && (scope.rank >= StorageRank::kWarp || op->dtype.is_handle() || |
950 | (is_known_size && const_nbits <= 32)); |
951 | |
952 | if (is_small_array || !is_flat_memory_space) { |
953 | return NewAlloc(op, attach_scope, scope, const_nbits); |
954 | } |
955 | |
956 | if (is_known_size) { |
957 | // constant allocation. |
958 | auto begin = const_free_map_.lower_bound(const_nbits / match_range); |
959 | auto mid = const_free_map_.lower_bound(const_nbits); |
960 | auto end = const_free_map_.upper_bound(const_nbits * match_range); |
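      // For example (a sketch), with match_range = 16 and a request of 1024 bits:
      // free entries between 1024 and 16384 bits are tried first, then smaller
      // entries down to 64 bits.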
961 | // start looking at the buffer that is bigger than the required size first |
962 | for (auto it = mid; it != end; ++it) { |
963 | StorageEntry* e = it->second; |
964 | if (e->attach_scope_ != attach_scope) continue; |
965 | if (e->scope != scope) continue; |
        // when not divisible, no reuse, e.g. float4 vs float3
967 | if (e->bits_offset % op_elem_bits != 0) continue; |
968 | e->const_nbits = std::max(const_nbits, e->const_nbits); |
969 | const_free_map_.erase(it); |
970 | return e; |
971 | } |
972 | // then start looking at smaller buffers. |
973 | for (auto it = mid; it != begin;) { |
974 | --it; |
975 | StorageEntry* e = it->second; |
976 | if (e->attach_scope_ != attach_scope) continue; |
977 | if (e->scope != scope) continue; |
978 | if (e->elem_type != op->dtype.element_of()) continue; |
979 | e->const_nbits = std::max(const_nbits, e->const_nbits); |
980 | const_free_map_.erase(it); |
981 | return e; |
982 | } |
983 | } else { |
      // Simple strategy: round robin.
985 | for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) { |
986 | StorageEntry* e = *it; |
987 | if (e->attach_scope_ != attach_scope) continue; |
988 | if (e->scope != scope) continue; |
989 | if (e->elem_type != op->dtype.element_of()) continue; |
990 | sym_free_list_.erase(it); |
991 | return e; |
992 | } |
993 | } |
994 | return NewAlloc(op, attach_scope, scope, const_nbits); |
995 | } |
996 | // simulated free. |
997 | void Free(const VarNode* var) { |
998 | auto it = alloc_map_.find(var); |
999 | ICHECK(it != alloc_map_.end()); |
1000 | StorageEntry* e = it->second; |
1001 | ICHECK_NE(e->allocs.size(), 0U); |
1002 | |
    // disable reuse of small arrays, as they will be lowered to registers in LLVM
    // This rule only applies if we are using non-special memory
1005 | if (e->scope.tag.length() == 0) { |
1006 | // Disable sharing of local memory. |
1007 | if (e->scope.rank >= StorageRank::kWarp || e->allocs[0]->dtype.is_handle()) return; |
1008 | // disable reuse of small arrays |
1009 | if (e->const_nbits > 0 && e->const_nbits <= 32) return; |
1010 | } |
1011 | // normal free. |
1012 | if (e->const_nbits != 0) { |
1013 | const_free_map_.insert({e->const_nbits, e}); |
1014 | } else { |
1015 | sym_free_list_.push_back(e); |
1016 | } |
1017 | } |
1018 | // thread scope. |
1019 | const Object* thread_scope_{nullptr}; |
1020 | // whether enable inplace detection. |
1021 | bool detect_inplace_{false}; |
1022 | // Locations of free ops. |
1023 | std::unordered_map<const Object*, EventEntry> event_map_; |
1024 | // constant size free map. |
1025 | std::multimap<uint64_t, StorageEntry*> const_free_map_; |
  // symbolic free list, for non-constant items.
1027 | std::list<StorageEntry*> sym_free_list_; |
1028 | // The allocation attach map |
1029 | std::unordered_map<const Object*, std::vector<StorageEntry*>> attach_map_; |
1030 | // The allocation assign map |
1031 | std::unordered_map<const VarNode*, StorageEntry*> alloc_map_; |
1032 | // The allocations |
1033 | std::vector<std::unique_ptr<StorageEntry>> alloc_vec_; |
1034 | // The buffer objects being remapped |
1035 | std::unordered_map<const BufferNode*, Buffer> buffer_remap_; |
1036 | // analyzer |
1037 | arith::Analyzer analyzer_; |
1038 | }; |
1039 | |
1040 | /* Helper struct containing information on how a buffer is declared and used |
1041 | * |
1042 | */ |
1043 | struct BufferVarInfo { |
1044 | enum DeclarationLocation { |
1045 | kPrimFuncParam = (1 << 0), |
1046 | kPrimFuncBufferMap = (1 << 1), |
1047 | kAllocateNode = (1 << 2), |
1048 | kAllocateConstNode = (1 << 3), |
1049 | kLetNode = (1 << 4), |
1050 | }; |
1051 | |
1052 | // The tir::Var that represents this buffer. |
1053 | Var var; |
1054 | |
1055 | // The data type of an element of the buffer. |
1056 | DataType element_dtype; |
1057 | |
1058 | /* The extent of the buffer. |
1059 | * |
1060 | * If multidimensional, the extent of the last dimension of the buffer. If the |
1061 | * size is unknown (e.g. pointer arguments to PrimFunc with no corresponding |
1062 | * entry in buffer_map), then extent is zero. |
1063 | */ |
1064 | PrimExpr extent; |
1065 | |
1066 | // Where the buffer was declared |
1067 | DeclarationLocation declaration_location; |
1068 | |
  // The element types as which the buffer is accessed. These may
  // differ from element_dtype either in base type (e.g. int32* cast to
  // float32* after packing in StorageRewrite) or in number of lanes
  // (e.g. float16* cast to float16x4*).
1073 | std::unordered_set<DataType> access_dtype; |
1074 | |
1075 | DataType get_preferred_dtype() const { |
1076 | std::unordered_set<DataType> base_access_dtype; |
1077 | for (auto dtype : access_dtype) { |
1078 | base_access_dtype.insert(dtype.element_of()); |
1079 | } |
1080 | // If the array is accessed as multiple base types within a |
1081 | // function, no point in changing the declared type. CodeGenC can |
1082 | // handle this with a type-cast prior to indexing. Vulkan will |
1083 | // raise an error at code-gen time, if a later pass doesn't split |
1084 | // it out. |
1085 | if (base_access_dtype.size() != 1) { |
1086 | return element_dtype; |
1087 | } |
1088 | |
1089 | DataType preferred_base_type = *base_access_dtype.begin(); |
1090 | |
1091 | // If there is only one vectorizable size used to access the |
1092 | // buffer, and if that access size is compatible with the array |
1093 | // size, then the buffer is vectorizable. In the future, this |
1094 | // could be improved to allow vectorized buffer access of size |
1095 | // GCD(*lanes_used), if necessary. |
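    //
    // For example (a sketch): a buffer declared as float16 with extent 256 that
    // is only ever accessed as float16x4 gets the preferred dtype float16x4,
    // since 256 is divisible by 4.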
1096 | int preferred_lanes = element_dtype.lanes(); |
1097 | if ((element_dtype.lanes() == 1) && (access_dtype.size() == 1)) { |
1098 | arith::Analyzer analyzer_; |
1099 | arith::ModularSet me = analyzer_.modular_set(extent); |
1100 | |
1101 | int lanes = access_dtype.begin()->lanes(); |
1102 | if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) { |
1103 | preferred_lanes = lanes; |
1104 | } |
1105 | } |
1106 | |
1107 | return preferred_base_type.with_lanes(preferred_lanes); |
1108 | } |
1109 | }; |
1110 | |
1111 | /* Checks whether buffers are accessed as scalar or vector parameters in a |
1112 | * function. |
1113 | * |
1114 | */ |
1115 | class VectorTypeAccessChecker : public StmtExprVisitor { |
1116 | public: |
1117 | /* Constructor |
1118 | * |
1119 | * @param params The parameters passed to a PrimFunc |
1120 | * |
1121 | * @param buffer_map The buffer_map associated with a PrimFunc |
1122 | * |
   * @param allow_untyped_pointers If a buffer or pointer variable is
   * missing a type annotation, assume that it has the same underlying
   * type as the values later accessed through it, with scalar element types.
1126 | */ |
1127 | VectorTypeAccessChecker(const Array<tir::Var>& params, const Map<Var, Buffer>& buffer_map, |
1128 | bool allow_untyped_pointers = false) |
1129 | : allow_untyped_pointers_(allow_untyped_pointers) { |
1130 | // If a parameter is in the buffer map, we want to track the |
1131 | // version in the map. |
1132 | for (auto it : buffer_map) { |
1133 | Buffer& buffer = it.second; |
1134 | Var buffer_var = buffer->data; |
1135 | DataType dtype = buffer->dtype; |
1136 | PrimExpr extent = buffer->shape.size() ? buffer->shape[buffer->shape.size() - 1] : 0; |
1137 | OnArrayDeclaration(buffer_var, dtype, extent, BufferVarInfo::kPrimFuncParam); |
1138 | } |
1139 | |
1140 | // If a pointer parameter isn't in the buffer map, then we want to |
1141 | // track the parameter itself. |
1142 | for (Var buffer_var : params) { |
1143 | auto pointer_type = GetPointerType(buffer_var->type_annotation); |
1144 | if (pointer_type.has_value() && (buffer_map.count(buffer_var) == 0)) { |
1145 | DataType dtype = pointer_type.value(); |
1146 | PrimExpr extent = 0; |
1147 | OnArrayDeclaration(buffer_var, dtype, extent, BufferVarInfo::kPrimFuncBufferMap); |
1148 | } |
1149 | } |
1150 | } |
1151 | |
1152 | void VisitExpr_(const LoadNode* op) final { |
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
1154 | } |
1155 | |
1156 | void VisitStmt_(const StoreNode* op) final { |
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
1158 | } |
1159 | |
1160 | void VisitExpr_(const BufferLoadNode* op) final { |
1161 | OnArrayAccess(op->dtype, op->buffer->data.get(), op->indices); |
1162 | StmtExprVisitor::VisitExpr_(op); |
1163 | } |
1164 | |
1165 | void VisitStmt_(const BufferStoreNode* op) final { |
1166 | OnArrayAccess(op->value.dtype(), op->buffer->data.get(), op->indices); |
1167 | StmtExprVisitor::VisitStmt_(op); |
1168 | } |
1169 | |
1170 | void VisitExpr_(const CallNode* op) final { |
1171 | if (op->op.same_as(builtin::tvm_access_ptr())) { |
1172 | DataType dtype = op->args[0].dtype(); |
1173 | const VarNode* buffer = op->args[1].as<VarNode>(); |
1174 | PrimExpr index = op->args[2]; |
1175 | OnArrayAccess(dtype, buffer, {index}); |
1176 | } |
1177 | StmtExprVisitor::VisitExpr_(op); |
1178 | } |
1179 | |
1180 | void VisitStmt_(const AllocateNode* op) final { |
1181 | const Array<PrimExpr>& extents = op->extents; |
1182 | PrimExpr extent = extents[extents.size() - 1]; |
1183 | OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateNode); |
1184 | |
1185 | StmtExprVisitor::VisitStmt_(op); |
1186 | } |
1187 | |
1188 | void VisitStmt_(const AllocateConstNode* op) final { |
1189 | const Array<PrimExpr>& extents = op->extents; |
1190 | PrimExpr extent = extents.size() ? extents[extents.size() - 1] : NullValue<PrimExpr>(); |
1191 | OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateConstNode); |
1192 | |
1193 | StmtExprVisitor::VisitStmt_(op); |
1194 | } |
1195 | |
1196 | void VisitExpr_(const LetNode* op) final { |
1197 | HandleLetNode(op->var); |
1198 | StmtExprVisitor::VisitExpr_(op); |
1199 | } |
1200 | |
1201 | void VisitStmt_(const LetStmtNode* op) final { |
1202 | HandleLetNode(op->var); |
1203 | StmtExprVisitor::VisitStmt_(op); |
1204 | } |
1205 | |
1206 | void HandleLetNode(Var let_var) { |
1207 | if (let_var->dtype.is_handle()) { |
1208 | auto pointer_type = GetPointerType(let_var->type_annotation); |
1209 | if (pointer_type.has_value()) { |
1210 | OnArrayDeclaration(let_var, pointer_type.value(), 0, BufferVarInfo::kLetNode); |
1211 | } else if (allow_untyped_pointers_) { |
1212 | OnArrayDeclaration(let_var, let_var->dtype, 0, BufferVarInfo::kLetNode); |
1213 | } else { |
        LOG(FATAL) << "Let statement of variable " << let_var->name_hint
                   << " is missing a type annotation, "
                   << "or its type annotation is not a pointer to a primitive type";
1217 | } |
1218 | } |
1219 | } |
1220 | |
1221 | /* Update the type map for a buffer based on its declaration |
1222 | * |
1223 | * @param buffer The VarNode representing the buffer. |
1224 | * |
1225 | * @param element_dtype The dtype of a single element of the buffer. |
   * If unknown, when used with the allow_untyped_pointers option,
1227 | * should be a handle dtype. |
1228 | * |
1229 | * @param extent The extent of the buffer. Zero if size is unknown. |
1230 | * |
1231 | * @param declaration_location How the buffer was allocated, so that |
1232 | * some locations can be rewritten without others. |
1233 | */ |
1234 | void OnArrayDeclaration(Var buffer, DataType element_dtype, PrimExpr extent, |
1235 | BufferVarInfo::DeclarationLocation declaration_location) { |
1236 | ICHECK(info_map_.find(buffer.get()) == info_map_.end()) |
        << "Array declaration of " << buffer->name_hint << " occurred multiple times.";
1238 | |
1239 | if (element_dtype == DataType::Bool()) { |
1240 | element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes()); |
1241 | } |
1242 | |
1243 | info_map_[buffer.get()] = {buffer, element_dtype, extent, declaration_location}; |
1244 | } |
1245 | |
1246 | /* Update the type map for a buffer based on its usage |
1247 | * |
1248 | * @param value_dtype The dtype of the value being stored to or |
1249 | * loaded from the buffer. |
1250 | * |
1251 | * @param buffer The VarNode representing the buffer. |
1252 | * |
   * @param indices The indices at which the value is being stored/loaded.
1256 | */ |
1257 | void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const Array<PrimExpr>& indices) { |
1258 | auto it = info_map_.find(buffer); |
1259 | ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer |
                                  << ") occurred before its declaration.";
1261 | BufferVarInfo& var_info = it->second; |
1262 | |
1263 | if (value_dtype.element_of() == DataType::Bool()) { |
1264 | value_dtype = DataType::Int(8).with_lanes(value_dtype.lanes()); |
1265 | } |
1266 | |
1267 | if (var_info.element_dtype.is_handle()) { |
1268 | ICHECK(allow_untyped_pointers_) << "Variable " << buffer->name_hint |
                                      << " was missing a type annotation in its declaration";
1270 | var_info.element_dtype = value_dtype.element_of(); |
1271 | } |
1272 | |
1273 | for (int i = 0; i < static_cast<int>(indices.size()) - 1; i++) { |
1274 | ICHECK(indices[i].dtype().is_scalar()) |
          << "Only the last index of a buffer access may be a vector type.";
1276 | } |
1277 | int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1; |
1278 | |
1279 | DataType access_dtype = value_dtype; |
1280 | |
1281 | int lanes_used = var_info.element_dtype.lanes(); |
1282 | |
1283 | // This can happen due to a previous pass that had rewrite_store_load = |
1284 | // false. This occurs from the StorageRewrite in tvm::lower, followed by the |
1285 | // PointerValueTypeRewrite in BuildSPIRV. The rewrite_store_load = false is |
1286 | // necessary because the C-based codegens do not yet support vectorized |
1287 | // pointer types (e.g. float16x4*). Once they do, this if statement should |
1288 | // instead be replaced by the below ICHECK_EQ. |
1289 | if (index_lanes * var_info.element_dtype.lanes() != value_dtype.lanes()) { |
1290 | ICHECK_EQ(index_lanes, value_dtype.lanes()); |
1291 | lanes_used = 1; |
1292 | var_info.element_dtype = var_info.element_dtype.with_lanes(1); |
1293 | } |
1294 | |
1295 | // TODO(Lunderberg): Uncomment this check once it can be applied. |
1296 | // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 |
1297 | // for discussion. |
1298 | |
1299 | // ICHECK_EQ(index_lanes * var_info.element_dtype.lanes(), value_dtype.lanes()) |
1300 | // << "Attempting to retrieve " << value_dtype.lanes() << " lanes of data with " |
1301 | // << index_lanes << " indices into an array whose elements have " |
1302 | // << var_info.element_dtype.lanes() << " lanes. " |
1303 | // << "Expected output with " << index_lanes * var_info.element_dtype.lanes() |
1304 | // << " lanes."; |
1305 | |
    // If the index is a RampNode with a stride of 1 and an offset
    // divisible by the number of lanes, and the predicate does not
    // apply any masking, then this array access could be vectorized.
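    //
    // For example (a sketch): A[Ramp(base=8, stride=1, lanes=4)], whose base is
    // provably divisible by 4, records an access dtype with 4 lanes.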
1310 | if (indices.size()) { |
1311 | const RampNode* ramp_index = indices[indices.size() - 1].as<RampNode>(); |
1312 | if (ramp_index && is_one(ramp_index->stride)) { |
1313 | arith::ModularSet me = analyzer_.modular_set(ramp_index->base); |
1314 | if ((me->coeff % ramp_index->lanes == 0) && (me->base % ramp_index->lanes == 0)) { |
1315 | lanes_used = ramp_index->lanes; |
1316 | } |
1317 | } |
1318 | } |
1319 | |
1320 | var_info.access_dtype.insert(access_dtype.with_lanes(lanes_used)); |
1321 | } |
1322 | |
1323 | // Map of buffer variable information determined |
1324 | std::unordered_map<const VarNode*, BufferVarInfo> info_map_; |
1325 | |
  // Whether to allow pointer variables that lack a type annotation.
1327 | bool allow_untyped_pointers_{false}; |
1328 | |
1329 | // internal analyzer |
1330 | arith::Analyzer analyzer_; |
1331 | }; |
1332 | |
1333 | /* \brief Rewrites buffer/pointer variables from scalar types to vectorized |
1334 | * types. |
1335 | * |
1336 | * Some runtimes do not allow casting between composite types and the underlying |
1337 | * base type (e.g. Vulkan, casting from 1-lane float16* to 4-lane float16x4*). |
1338 | * In these cases, in order to have vectorized load/store on an array, the |
1339 | * element type of that array must be vectorized. This is in contrast to C-style |
1340 | * runtimes, in which `float16x4* vec = *(float16x4*)(float_arr + offset)` is |
1341 | * valid. |
1342 | * |
1343 | * By default, VectorTypeRewriter will attempt to rewrite all buffer variables to |
 * vectorized access, if the loads/stores occurring in the PrimFunc are all
1345 | * vectorized. This includes adjusting the indices being used to access the |
1346 | * array. (e.g. If `float16* scalar_arr` is being converted to `float16x4* |
1347 | * vec_arr`, then `scalar_arr[Ramp(offset, 1, 4)]` will be converted to |
1348 | * `vec_arr[offset/4]`.) |
1349 | * |
1350 | * Currently, several of the C-style runtimes do not support buffers whose |
1351 | * elements are vectorized types, or rely on the presence of the Ramp nodes to |
1352 | * identify vectorized loads. The boolean parameters in the constructor are to |
1353 | * mimic the previous behavior of VectorTypeRewriter, to avoid breaking these |
1354 | * runtimes. Once all runtimes support vectorized buffer elements, these |
1355 | * parameters can be removed. |
1356 | */ |
1357 | class VectorTypeRewriter : public StmtExprMutator { |
1358 | public: |
1359 | /* Constructor |
1360 | * |
   * @param info_map The buffer access information previously collected by
   * VectorTypeAccessChecker from the PrimFunc
1363 | * |
1364 | * @param rewrite_params Whether pointer-type parameters passed into the |
1365 | * function should be rewritten from scalar types to vectorized types. |
1366 | * |
1367 | * @param rewrite_buffer_map Whether buffers present in the buffer_map should |
1368 | * have their data variable be rewritten from scalar types to vectorized types. |
1369 | * |
1370 | * @param rewrite_allocate_node Whether the buffer variable associated with |
1371 | * AllocateNodes should be rewritten from scalar types to vectorized types. |
1372 | * |
1373 | * @param rewrite_indices Whether the indices to the Load and Store nodes |
1374 | * should be rewritten to correspond to the new buffer_var type. |
1375 | * |
1376 | * @param rewrite_let_node Whether pointer declarations in let nodes |
1377 | * should be re-written. |
1378 | */ |
1379 | VectorTypeRewriter(const std::unordered_map<const VarNode*, BufferVarInfo>& info_map, |
1380 | bool rewrite_params = true, bool rewrite_buffer_map = true, |
1381 | bool rewrite_allocate_node = true, bool rewrite_indices = true, |
1382 | bool rewrite_let_node = true, bool rewrite_allocate_const_node = true) |
1383 | : rewrite_indices_(rewrite_indices) { |
1384 | int rewrite_mask = 0; |
1385 | if (rewrite_params) { |
1386 | rewrite_mask |= BufferVarInfo::kPrimFuncParam; |
1387 | } |
1388 | if (rewrite_buffer_map) { |
1389 | rewrite_mask |= BufferVarInfo::kPrimFuncBufferMap; |
1390 | } |
1391 | if (rewrite_allocate_node) { |
1392 | rewrite_mask |= BufferVarInfo::kAllocateNode; |
1393 | } |
1394 | if (rewrite_let_node) { |
1395 | rewrite_mask |= BufferVarInfo::kLetNode; |
1396 | } |
1397 | if (rewrite_allocate_const_node) { |
1398 | rewrite_mask |= BufferVarInfo::kAllocateConstNode; |
1399 | } |
1400 | |
1401 | // Rewrite any buffer variables whose preferred type isn't their current type. |
1402 | for (const auto& pair : info_map) { |
1403 | const auto& var_info = pair.second; |
1404 | DataType preferred = var_info.get_preferred_dtype(); |
1405 | if (preferred != var_info.element_dtype && (rewrite_mask & var_info.declaration_location)) { |
1406 | Var old_buffer_var = var_info.var; |
1407 | Var new_buffer_var(old_buffer_var->name_hint, |
1408 | PointerType(PrimType(preferred), GetPtrStorageScope(old_buffer_var)), |
1409 | old_buffer_var->span); |
1410 | |
1411 | rewrite_map_[var_info.var.get()] = {var_info.var, new_buffer_var, var_info.element_dtype, |
1412 | preferred}; |
1413 | } |
1414 | } |
1415 | } |
1416 | |
1417 | PrimExpr VisitExpr_(const LoadNode* op) final { |
1418 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
1419 | } |
1420 | |
1421 | Stmt VisitStmt_(const StoreNode* op) final { |
1422 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
1423 | } |
1424 | |
1425 | template <typename Node> |
1426 | Node VisitBufferAccess(Node node) { |
1427 | if (!rewrite_indices_) { |
1428 | return node; |
1429 | } |
1430 | |
1431 | auto it = rewrite_map_.find(node->buffer->data.get()); |
1432 | if (it == rewrite_map_.end()) { |
1433 | return node; |
1434 | } |
1435 | const auto& info = it->second; |
1436 | |
1437 | Array<PrimExpr> indices = node->indices; |
1438 | |
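    // Rewrite the innermost index: a unit-stride Ramp over the old scalar
    // elements becomes an index over the new vector elements (e.g. with
    // factor 4, A[Ramp(i*4, 1, 4)] becomes A_vec[i]).  If the ramp is wider
    // than the rewrite factor, a narrower Ramp over vector elements remains.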
1439 | const RampNode* ramp_index = indices[indices.size() - 1].as<RampNode>(); |
1440 | if (ramp_index && is_one(ramp_index->stride)) { |
1441 | PrimExpr new_index = |
1442 | ramp_index->base / make_const(ramp_index->base.dtype(), ramp_index->lanes); |
1443 | if (ramp_index->lanes != info.factor()) { |
1444 | new_index = Ramp(new_index, ramp_index->stride, ramp_index->lanes / info.factor(), |
1445 | ramp_index->span); |
1446 | } |
1447 | |
1448 | indices.Set(indices.size() - 1, new_index); |
1449 | } |
1450 | |
1451 | auto writer = node.CopyOnWrite(); |
1452 | writer->buffer = RemapBuffer(node->buffer); |
1453 | writer->indices = indices; |
1454 | |
1455 | return node; |
1456 | } |
1457 | |
1458 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
1459 | auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
1460 | auto modified = VisitBufferAccess(node); |
1461 | |
    // LegalizeDType() is only needed for BufferLoadNode, not for
    // BufferStoreNode, so it can't simply be called from VisitBufferAccess.
1464 | if (node.same_as(modified)) { |
1465 | return std::move(node); |
1466 | } else { |
1467 | auto writer = modified.CopyOnWrite(); |
1468 | writer->LegalizeDType(); |
1469 | return std::move(modified); |
1470 | } |
1471 | } |
1472 | |
1473 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
1474 | auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
1475 | return VisitBufferAccess(std::move(node)); |
1476 | } |
1477 | |
1478 | Stmt VisitStmt_(const LetStmtNode* op) final { |
1479 | auto it = rewrite_map_.find(op->var.get()); |
1480 | PrimExpr value = this->VisitExpr(op->value); |
1481 | Stmt body = this->VisitStmt(op->body); |
1482 | Var var = (it == rewrite_map_.end()) ? op->var : it->second.new_buffer_var; |
1483 | if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { |
1484 | return GetRef<Stmt>(op); |
1485 | } |
1486 | return LetStmt(var, value, body); |
1487 | } |
1488 | |
1489 | Buffer RemapBuffer(Buffer buf) { |
1490 | auto cache_key = buf.get(); |
1491 | |
1492 | auto cache_it = buffer_map_.find(cache_key); |
1493 | if (cache_it != buffer_map_.end()) { |
1494 | return cache_it->second; |
1495 | } |
1496 | |
1497 | auto info_it = rewrite_map_.find(buf->data.get()); |
1498 | if (info_it != rewrite_map_.end()) { |
1499 | auto& info = info_it->second; |
1500 | |
1501 | Array<PrimExpr> shape = buf->shape; |
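      // Only the innermost (fastest-varying) dimension shrinks; the element
      // dtype gains the corresponding number of lanes, so the buffer's total
      // size in bytes is unchanged.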
1502 | PrimExpr last_dim = shape[shape.size() - 1]; |
1503 | shape.Set(shape.size() - 1, last_dim / make_const(last_dim.dtype(), info.factor())); |
1504 | |
1505 | auto writer = buf.CopyOnWrite(); |
1506 | writer->data = info.new_buffer_var; |
1507 | writer->dtype = info.new_element_dtype; |
1508 | writer->shape = shape; |
1509 | } |
1510 | |
1511 | buffer_map_[cache_key] = buf; |
1512 | return buf; |
1513 | } |
1514 | |
1515 | PrimExpr VisitExpr_(const CallNode* op) final { |
1516 | if (op->op.same_as(builtin::tvm_access_ptr())) { |
1517 | PrimExpr expr = StmtExprMutator::VisitExpr_(op); |
1518 | op = expr.as<CallNode>(); |
1519 | |
1520 | if (!rewrite_indices_) { |
1521 | return expr; |
1522 | } |
1523 | |
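      // tvm_access_ptr arguments are {type annotation, buffer var, offset,
      // extent, rw mask}.  Offset and extent are measured in elements, so
      // both are scaled down by the vectorization factor below.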
1524 | const VarNode* buffer_var = op->args[1].as<VarNode>(); |
1525 | auto it = rewrite_map_.find(buffer_var); |
1526 | if (it == rewrite_map_.end()) { |
1527 | return expr; |
1528 | } |
1529 | const auto& info = it->second; |
1530 | |
1531 | PrimExpr index = op->args[2]; |
1532 | PrimExpr extent = op->args[3]; |
1533 | PrimExpr flag = op->args[4]; |
1534 | |
1535 | PrimExpr e_dtype = tir::TypeAnnotation(info.new_element_dtype); |
1536 | int factor = info.factor(); |
1537 | extent = extent / make_const(extent.dtype(), factor); |
1538 | index = index / make_const(index.dtype(), factor); |
1539 | Array<PrimExpr> acc_args{e_dtype, info.new_buffer_var, index, extent, flag}; |
1540 | return Call(info.new_element_dtype, builtin::tvm_access_ptr(), acc_args); |
1541 | |
1542 | } else { |
1543 | return StmtExprMutator::VisitExpr_(op); |
1544 | } |
1545 | } |
1546 | |
1547 | Stmt VisitStmt_(const AllocateNode* op) final { |
1548 | Stmt stmt = StmtExprMutator::VisitStmt_(op); |
1549 | op = stmt.as<AllocateNode>(); |
1550 | |
1551 | auto it = rewrite_map_.find(op->buffer_var.get()); |
1552 | if (it == rewrite_map_.end()) { |
1553 | return stmt; |
1554 | } |
1555 | |
1556 | const auto& info = it->second; |
1557 | |
1558 | Var new_buffer_var = info.new_buffer_var; |
1559 | |
1560 | Array<PrimExpr> extents = op->extents; |
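    // As in RemapBuffer, divide only the innermost extent by the
    // vectorization factor; the allocation's total size is unchanged.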
1561 | PrimExpr last_extent = extents[extents.size() - 1]; |
1562 | extents.Set(extents.size() - 1, last_extent / make_const(last_extent.dtype(), info.factor())); |
1563 | return Allocate(new_buffer_var, info.new_element_dtype, extents, op->condition, op->body); |
1564 | } |
1565 | |
1566 | Stmt VisitStmt_(const AllocateConstNode* op) final { |
1567 | Stmt stmt = StmtExprMutator::VisitStmt_(op); |
1568 | op = stmt.as<AllocateConstNode>(); |
1569 | |
1570 | auto it = rewrite_map_.find(op->buffer_var.get()); |
1571 | if (it == rewrite_map_.end()) { |
1572 | return stmt; |
1573 | } |
1574 | |
1575 | const auto& info = it->second; |
1576 | |
1577 | Var new_buffer_var = info.new_buffer_var; |
1578 | |
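    // The constant buffer's dtype may already carry lanes, so compute the
    // widening factor relative to op->dtype rather than reusing info.factor().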
1579 | int factor = info.new_element_dtype.lanes() / op->dtype.lanes(); |
1580 | |
1581 | Array<PrimExpr> extents = op->extents; |
1582 | extents.Set(extents.size() - 1, |
1583 | extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); |
1584 | return AllocateConst(new_buffer_var, info.new_element_dtype, extents, op->data, op->body); |
1585 | } |
1586 | |
1587 | /* Update the parameters and all remaining variable references |
1588 | * |
1589 | * Should be called after calling operator() on the body of the |
1590 | * function. |
1591 | * |
   * @param func_ptr A pointer to the PrimFunc being modified.
1593 | */ |
1594 | void Finalize(PrimFunc* func_ptr) { |
    ICHECK(func_ptr) << "Finalize expects a non-null pointer";
1596 | auto& func = *func_ptr; |
1597 | auto* n = func.CopyOnWrite(); |
1598 | |
1599 | // Remap any remaining references to the old buffer variables |
1600 | Map<Var, PrimExpr> var_remap; |
1601 | for (const auto& pair : rewrite_map_) { |
1602 | const auto& info = pair.second; |
1603 | var_remap.Set(info.old_buffer_var, info.new_buffer_var); |
1604 | } |
1605 | n->body = Substitute(n->body, var_remap); |
1606 | |
1607 | // Remap the argument list to use the new buffer variables. |
1608 | Array<Var> new_params; |
1609 | for (const auto& old_param : n->params) { |
1610 | auto it = rewrite_map_.find(old_param.get()); |
1611 | if (it == rewrite_map_.end()) { |
1612 | new_params.push_back(old_param); |
1613 | } else { |
1614 | const auto& info = it->second; |
1615 | new_params.push_back(info.new_buffer_var); |
1616 | } |
1617 | } |
1618 | n->params = new_params; |
1619 | |
1620 | // Remap the Buffer objects in PrimFunc::buffer_map so that the |
1621 | // buffers use the new buffer variables |
1622 | Map<Var, Buffer> new_buffer_map; |
1623 | for (const auto& pair : n->buffer_map) { |
1624 | Var key = pair.first; |
1625 | Buffer old_buffer = pair.second; |
1627 | Buffer new_buffer = RemapBuffer(old_buffer); |
1628 | new_buffer_map.Set(key, new_buffer); |
1629 | } |
1630 | n->buffer_map = new_buffer_map; |
1631 | } |
1632 | |
1633 | private: |
1634 | struct RewriteInfo { |
1635 | Var old_buffer_var; |
1636 | Var new_buffer_var; |
1637 | DataType old_element_dtype; |
1638 | DataType new_element_dtype; |
1639 | |
1640 | int factor() const { |
1641 | int old_lanes = old_element_dtype.lanes(); |
1642 | int new_lanes = new_element_dtype.lanes(); |
1643 | ICHECK_EQ(new_lanes % old_lanes, 0); |
1644 | return new_lanes / old_lanes; |
1645 | } |
1646 | }; |
1647 | |
1648 | bool rewrite_indices_{true}; |
1649 | std::unordered_map<const VarNode*, RewriteInfo> rewrite_map_; |
1650 | std::unordered_map<const BufferNode*, Buffer> buffer_map_; |
1651 | }; |
1652 | |
1653 | // Rewrite allocates, pointer parameters, and buffer map into vectorized versions |
1654 | // if each access into a buffer is the same vector type. |
1655 | PrimFunc PointerValueTypeRewrite(PrimFunc f, bool allow_untyped_pointers = false, |
1656 | bool rewrite_params = true, bool rewrite_buffer_map = true, |
1657 | bool rewrite_allocate_node = true, bool rewrite_indices = true, |
1658 | bool rewrite_let_node = true, |
1659 | bool rewrite_allocate_const_node = true) { |
1660 | VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers); |
1661 | checker(f->body); |
1662 | |
1663 | VectorTypeRewriter rewriter(checker.info_map_, rewrite_params, rewrite_buffer_map, |
1664 | rewrite_allocate_node, rewrite_indices, rewrite_let_node, |
1665 | rewrite_allocate_const_node); |
1666 | PrimFuncNode* n = f.CopyOnWrite(); |
1667 | n->body = rewriter(std::move(n->body)); |
1668 | rewriter.Finalize(&f); |
1669 | |
1670 | return f; |
1671 | } |
1672 | |
1673 | namespace transform { |
1674 | |
1675 | Pass StorageRewrite() { |
1676 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
1677 | auto* n = f.CopyOnWrite(); |
1678 | n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true); |
    // Parameters may not be rewritten, but internal allocations may.
    // Vectorization of AllocateConst is currently disabled: types that
    // include padding (e.g. int8x3 padded out to 32 bits) would have
    // indexing issues that require either rewriting AllocateConst::data
    // or having the code generators handle vectorized constants.
    return PointerValueTypeRewrite(std::move(f), /*allow_untyped_pointers=*/true,
                                   /*rewrite_params=*/false, /*rewrite_buffer_map=*/false,
                                   /*rewrite_allocate_node=*/true, /*rewrite_indices=*/true,
                                   /*rewrite_let_node=*/true,
                                   /*rewrite_allocate_const_node=*/false);
1686 | }; |
  return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
1688 | } |
1689 | |
1690 | TVM_REGISTER_GLOBAL("tir.transform.StorageRewrite" ).set_body_typed(StorageRewrite); |
1691 | |
1692 | Pass PointerValueTypeRewrite() { |
1693 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
1694 | return PointerValueTypeRewrite(std::move(f)); |
1695 | }; |
  return CreatePrimFuncPass(pass_func, 0, "tir.PointerValueTypeRewrite", {});
1697 | } |
1698 | |
1699 | TVM_REGISTER_GLOBAL("tir.transform.PointerValueTypeRewrite" ) |
1700 | .set_body_typed(PointerValueTypeRewrite); |
1701 | |
1702 | } // namespace transform |
1703 | |
1704 | } // namespace tir |
1705 | } // namespace tvm |
1706 | |