1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file storage_access.cc
 * \brief Implementation of StorageAccessVisitor, which collects buffer
 *        read/write/sync access entries from a TIR statement tree.
 */
23 | #include "storage_access.h" |
24 | |
25 | #include <tvm/target/target_info.h> |
26 | #include <tvm/tir/op.h> |
27 | |
28 | #include <string> |
29 | #include <utility> |
30 | |
31 | #include "ir_utils.h" |
32 | |
33 | namespace tvm { |
34 | namespace tir { |
35 | |
36 | void StorageAccessVisitor::VisitExpr_(const LoadNode* op) { |
37 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
38 | } |
39 | |
40 | void StorageAccessVisitor::VisitStmt_(const StoreNode* op) { |
41 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
42 | } |
43 | |
44 | void StorageAccessVisitor::VisitExpr_(const BufferLoadNode* op) { |
45 | Var buf = op->buffer->data; |
46 | StorageScope scope = GetScope(buf); |
47 | if (Enabled(buf.get(), scope)) { |
48 | ICHECK(allow_append_) << op << " " << scope.to_string(); |
49 | AccessEntry e; |
50 | e.threads = env_threads(); |
51 | e.buffer = buf; |
52 | e.dtype = op->dtype.element_of(); |
53 | for (const auto& index : op->indices) { |
54 | e.touched.push_back(arith::IntSet::Vector(index)); |
55 | } |
56 | e.type = kRead; |
57 | e.scope = scope; |
58 | curr_stmt_.access.emplace_back(std::move(e)); |
59 | } |
60 | // traverse child |
61 | StmtExprVisitor::VisitExpr_(op); |
62 | } |
63 | |
64 | void StorageAccessVisitor::VisitStmt_(const BufferStoreNode* op) { |
65 | allow_append_ = true; |
66 | ICHECK_EQ(curr_stmt_.access.size(), 0U); |
67 | curr_stmt_.stmt = op; |
68 | |
69 | Var buf = op->buffer->data; |
70 | StorageScope scope = GetScope(buf); |
71 | if (Enabled(buf.get(), scope)) { |
72 | AccessEntry e; |
73 | e.threads = env_threads(); |
74 | e.buffer = buf; |
75 | e.dtype = op->value.dtype().element_of(); |
76 | for (const auto& index : op->indices) { |
77 | e.touched.push_back(arith::IntSet::Vector(index)); |
78 | } |
79 | e.type = kWrite; |
80 | e.scope = scope; |
81 | curr_stmt_.access.emplace_back(std::move(e)); |
82 | } |
83 | // traverse child |
84 | StmtExprVisitor::VisitStmt_(op); |
85 | // push to the scope |
86 | scope_.back().push_back(curr_stmt_); |
87 | // clear access entry. |
88 | curr_stmt_.access.clear(); |
89 | allow_append_ = false; |
90 | } |
91 | |
92 | void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) { |
93 | allow_append_ = true; |
94 | ICHECK_EQ(curr_stmt_.access.size(), 0U); |
95 | curr_stmt_.stmt = op; |
96 | StmtExprVisitor::VisitStmt_(op); |
97 | // push to the scope |
98 | if (curr_stmt_.access.size() != 0) { |
99 | scope_.back().push_back(curr_stmt_); |
100 | curr_stmt_.access.clear(); |
101 | } |
102 | allow_append_ = false; |
103 | } |
104 | |
// Dispatch on the attribute key. Each branch carefully pushes and pops
// visitor state (scope_, env_threads_, double_buffer_write_, in_device_env_)
// around the recursive visit; statement order here is load-bearing.
void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
  if (op->attr_key == attr::double_buffer_write) {
    // Collect accesses of the annotated subtree in a fresh scope, then mark
    // every write to the annotated buffer as a double-buffer write.
    ICHECK(double_buffer_write_ == nullptr);  // no nesting of this attribute
    double_buffer_write_ = op->node.as<VarNode>();
    scope_.push_back(std::vector<StmtEntry>());
    StmtExprVisitor::VisitStmt_(op);
    StmtEntry s;
    s.stmt = op;
    s.access = Summarize(std::move(scope_.back()), nullptr);
    scope_.pop_back();
    if (!s.access.empty()) {
      for (AccessEntry& e : s.access) {
        if (e.type == kWrite && e.buffer.get() == double_buffer_write_) {
          e.double_buffer_write = true;
        }
      }
      // Re-emit the summarized entry into the enclosing scope.
      scope_.back().emplace_back(std::move(s));
    }
    double_buffer_write_ = nullptr;
  } else if (op->attr_key == attr::coproc_scope) {
    // Treat the coprocessor IterVar as an environment thread for the
    // duration of the body visit.
    IterVar iv = Downcast<IterVar>(op->node);
    env_threads_.push_back(iv);
    StmtExprVisitor::VisitStmt_(op);
    env_threads_.pop_back();
  } else if (op->attr_key == attr::thread_extent) {
    IterVar iv = Downcast<IterVar>(op->node);
    env_threads_.push_back(iv);
    if (!in_device_env_) {
      // Outermost thread_extent: entering the device environment. Accesses
      // inside get their own scope and are summarized (result discarded).
      in_device_env_ = true;
      scope_.push_back(std::vector<StmtEntry>());
      StmtExprVisitor::VisitStmt_(op);
      // no need to take the result as the thread barrier automatically syncs.
      Summarize(std::move(scope_.back()), nullptr);
      in_device_env_ = false;
      scope_.pop_back();
    } else {
      // Nested thread_extent: just record the extra environment thread.
      StmtExprVisitor::VisitStmt_(op);
    }
    env_threads_.pop_back();
  } else if (op->attr_key == attr::hand_threaded) {
    // skip this pass on blocks that were hand_threaded
    // this avoids control flow and read/write conflicts
    // between hand-threaded kernels and automatic threading
  } else {
    StmtExprVisitor::VisitStmt_(op);
  }
}
152 | |
153 | void StorageAccessVisitor::VisitStmt_(const ForNode* op) { |
154 | scope_.push_back(std::vector<StmtEntry>()); |
155 | StmtExprVisitor::VisitStmt_(op); |
156 | StmtEntry s; |
157 | s.stmt = op; |
158 | s.access = Summarize(std::move(scope_.back()), op); |
159 | scope_.pop_back(); |
160 | if (s.access.size() != 0) { |
161 | // relax the touched set to contain all ranges in the loop. |
162 | std::unordered_map<const VarNode*, arith::IntSet> relax_map; |
163 | relax_map[op->loop_var.get()] = |
164 | arith::IntSet::FromRange(Range::FromMinExtent(op->min, op->extent)); |
165 | for (AccessEntry& e : s.access) { |
166 | if (e.buffer.defined()) { |
167 | ICHECK(e.touched.size()); |
168 | Array<arith::IntSet> new_touched; |
169 | for (const auto& touched : e.touched) { |
170 | new_touched.push_back(arith::EvalSet(touched, relax_map)); |
171 | } |
172 | e.touched = std::move(new_touched); |
173 | } |
174 | } |
175 | } |
176 | if (!s.access.empty()) { |
177 | scope_.back().emplace_back(std::move(s)); |
178 | } |
179 | } |
180 | |
181 | void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) { |
182 | ++condition_counter_; |
183 | this->VisitExpr(op->condition); |
184 | scope_.push_back(std::vector<StmtEntry>()); |
185 | this->VisitStmt(op->then_case); |
186 | StmtEntry s; |
187 | s.stmt = op; |
188 | s.access = Summarize(std::move(scope_.back()), nullptr); |
189 | scope_.pop_back(); |
190 | if (op->else_case) { |
191 | scope_.push_back(std::vector<StmtEntry>()); |
192 | this->VisitStmt(op->else_case.value()); |
193 | auto v = Summarize(std::move(scope_.back()), nullptr); |
194 | scope_.pop_back(); |
195 | s.access.insert(s.access.end(), v.begin(), v.end()); |
196 | } |
197 | scope_.back().emplace_back(std::move(s)); |
198 | --condition_counter_; |
199 | } |
200 | |
201 | void StorageAccessVisitor::VisitStmt_(const WhileNode* op) { |
202 | ++condition_counter_; |
203 | this->VisitExpr(op->condition); |
204 | scope_.push_back(std::vector<StmtEntry>()); |
205 | this->VisitStmt(op->body); |
206 | StmtEntry s; |
207 | s.stmt = op; |
208 | s.access = Summarize(std::move(scope_.back()), nullptr); |
209 | scope_.pop_back(); |
210 | scope_.back().emplace_back(std::move(s)); |
211 | --condition_counter_; |
212 | } |
213 | |
214 | void StorageAccessVisitor::VisitExpr_(const CallNode* op) { |
215 | if (op->op.same_as(builtin::address_of())) { |
216 | const BufferLoadNode* load = op->args[0].as<BufferLoadNode>(); |
217 | StmtExprVisitor::VisitExpr_(load); |
218 | } else if (op->op.same_as(builtin::tvm_access_ptr())) { |
219 | ICHECK_EQ(op->args.size(), 5U); |
220 | DataType dtype = op->args[0].dtype(); |
221 | const VarNode* buffer = op->args[1].as<VarNode>(); |
222 | PrimExpr offset = op->args[2]; |
223 | PrimExpr extent = op->args[3]; |
224 | const IntImmNode* flag = op->args[4].as<IntImmNode>(); |
225 | StorageScope scope = GetScope(GetRef<Var>(buffer)); |
226 | // The buffer scope. |
227 | if (Enabled(buffer, scope)) { |
228 | ICHECK(allow_append_); |
229 | AccessEntry e; |
230 | e.threads = env_threads(); |
231 | e.dtype = dtype; |
232 | e.buffer = Downcast<Var>(op->args[1]); |
233 | e.touched = {arith::IntSet::FromRange(Range::FromMinExtent(offset, extent))}; |
234 | e.scope = scope; |
235 | if (flag->value & 1) { |
236 | e.type = kRead; |
237 | curr_stmt_.access.emplace_back(e); |
238 | } |
239 | if (flag->value & 2) { |
240 | e.type = kWrite; |
241 | curr_stmt_.access.emplace_back(e); |
242 | } |
243 | } |
244 | StmtExprVisitor::VisitExpr_(op); |
245 | } else if (op->op.same_as(builtin::tvm_storage_sync())) { |
246 | ICHECK(allow_append_); |
247 | const std::string& s = op->args[0].as<StringImmNode>()->value; |
248 | if (s != "warp" ) { |
249 | StorageScope scope = StorageScope::Create(s); |
250 | AccessEntry e; |
251 | e.threads = env_threads(); |
252 | e.type = kSync; |
253 | e.scope = StorageScope::Create(s); |
254 | curr_stmt_.access.emplace_back(std::move(e)); |
255 | } |
256 | } else { |
257 | StmtExprVisitor::VisitExpr_(op); |
258 | } |
259 | } |
260 | |
261 | StorageScope StorageAccessVisitor::GetScope(Var buffer_var) const { |
262 | if (buffer_var->type_annotation.as<PointerTypeNode>()) { |
263 | return StorageScope::Create(GetPtrStorageScope(buffer_var)); |
264 | } |
265 | return StorageScope(); // global by default |
266 | } |
267 | |
268 | } // namespace tir |
269 | } // namespace tvm |
270 | |