1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file update_pointer_storage_scope.cc |
22 | * \brief A pass to update storage scopes for buffer variables. |
23 | */ |
24 | #include "update_pointer_storage_scope.h" |
25 | |
26 | #include <tvm/tir/expr.h> |
27 | #include <tvm/tir/op.h> |
28 | #include <tvm/tir/stmt_functor.h> |
29 | #include <tvm/tir/transform.h> |
30 | |
31 | #include <unordered_map> |
32 | |
33 | #include "../../runtime/thread_storage_scope.h" |
34 | #include "ir_utils.h" |
35 | |
36 | namespace tvm { |
37 | namespace tir { |
38 | |
39 | Var WithStorageScope(const VarNode* buffer_var, String storage_scope) { |
40 | auto* ptr_type = buffer_var->type_annotation.as<PointerTypeNode>(); |
41 | ICHECK(ptr_type) << "The provided variable is not of pointer type" ; |
42 | return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), |
43 | buffer_var->span); |
44 | } |
45 | |
46 | UpdatePointerStorageScope::UpdatePointerStorageScope( |
47 | const std::unordered_map<const VarNode*, String>& new_storage_scopes) { |
48 | for (auto& kv : new_storage_scopes) { |
49 | new_var_remap_[kv.first] = WithStorageScope(kv.first, kv.second); |
50 | } |
51 | } |
52 | |
53 | PrimExpr UpdatePointerStorageScope::VisitExpr_(const VarNode* op) { |
54 | auto it = new_var_remap_.find(op); |
55 | if (it == new_var_remap_.end()) { |
56 | return GetRef<Var>(op); |
57 | } |
58 | return it->second; |
59 | } |
60 | |
61 | Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) { |
62 | auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(op->buffer_var)); |
63 | return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition), |
64 | StmtExprMutator::VisitStmt(op->body)); |
65 | } |
66 | |
67 | template <typename Node> |
68 | Node UpdatePointerStorageScope::UpdateBufferAccess(Node node) { |
69 | auto new_buffer = GetUpdatedBuffer(node->buffer); |
70 | if (!new_buffer.same_as(node->buffer)) { |
71 | auto writer = node.CopyOnWrite(); |
72 | writer->buffer = new_buffer; |
73 | } |
74 | return node; |
75 | } |
76 | |
77 | Buffer UpdatePointerStorageScope::GetUpdatedBuffer(Buffer buf) { |
78 | // Use the cached buffer, if it exists. |
79 | auto key = buf.get(); |
80 | auto it = new_buffer_remap_.find(key); |
81 | if (it != new_buffer_remap_.end()) { |
82 | return it->second; |
83 | } |
84 | |
85 | // Update the buffer's var, if needed. |
86 | auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(buf->data)); |
87 | if (!remapped.same_as(buf->data)) { |
88 | auto writer = buf.CopyOnWrite(); |
89 | writer->data = remapped; |
90 | } |
91 | |
92 | // Update the cache and return |
93 | new_buffer_remap_[key] = buf; |
94 | return buf; |
95 | } |
96 | |
97 | PrimExpr UpdatePointerStorageScope::VisitExpr_(const LoadNode* op) { |
98 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
99 | } |
100 | |
101 | PrimExpr UpdatePointerStorageScope::VisitExpr_(const BufferLoadNode* op) { |
102 | auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
103 | return UpdateBufferAccess(node); |
104 | } |
105 | |
106 | Stmt UpdatePointerStorageScope::VisitStmt_(const StoreNode* op) { |
107 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
108 | } |
109 | |
110 | Stmt UpdatePointerStorageScope::VisitStmt_(const BufferStoreNode* op) { |
111 | auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
112 | return UpdateBufferAccess(node); |
113 | } |
114 | |
115 | } // namespace tir |
116 | } // namespace tvm |
117 | |