/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file storage_access.cc
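 * \brief Implementation of StorageAccessVisitor (declared in storage_access.h),
 *        which collects buffer read/write/sync access information per statement scope.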
 */
#include "storage_access.h"

#include <tvm/target/target_info.h>
#include <tvm/tir/op.h>

#include <string>
#include <utility>

#include "ir_utils.h"

namespace tvm {
namespace tir {

void StorageAccessVisitor::VisitExpr_(const LoadNode* op) {
  LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
}

void StorageAccessVisitor::VisitStmt_(const StoreNode* op) {
  LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
}

void StorageAccessVisitor::VisitExpr_(const BufferLoadNode* op) {
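  // Record a read access entry when the buffer's storage scope is tracked by this visitor.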
  Var buf = op->buffer->data;
  StorageScope scope = GetScope(buf);
  if (Enabled(buf.get(), scope)) {
    ICHECK(allow_append_) << op << " " << scope.to_string();
    AccessEntry e;
    e.threads = env_threads();
    e.buffer = buf;
    e.dtype = op->dtype.element_of();
    for (const auto& index : op->indices) {
      e.touched.push_back(arith::IntSet::Vector(index));
    }
    e.type = kRead;
    e.scope = scope;
    curr_stmt_.access.emplace_back(std::move(e));
  }
  // traverse child
  StmtExprVisitor::VisitExpr_(op);
}

void StorageAccessVisitor::VisitStmt_(const BufferStoreNode* op) {
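  // A store opens a new statement entry; loads visited inside the stored value and
  // indices may append further read accesses while allow_append_ is set.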
  allow_append_ = true;
  ICHECK_EQ(curr_stmt_.access.size(), 0U);
  curr_stmt_.stmt = op;

  Var buf = op->buffer->data;
  StorageScope scope = GetScope(buf);
  if (Enabled(buf.get(), scope)) {
    AccessEntry e;
    e.threads = env_threads();
    e.buffer = buf;
    e.dtype = op->value.dtype().element_of();
    for (const auto& index : op->indices) {
      e.touched.push_back(arith::IntSet::Vector(index));
    }
    e.type = kWrite;
    e.scope = scope;
    curr_stmt_.access.emplace_back(std::move(e));
  }
  // traverse child
  StmtExprVisitor::VisitStmt_(op);
  // push to the scope
  scope_.back().push_back(curr_stmt_);
  // clear access entry.
  curr_stmt_.access.clear();
  allow_append_ = false;
}

void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) {
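  // Accesses appended while visiting the evaluated expression (e.g. buffer loads,
  // tvm_access_ptr, tvm_storage_sync) form this statement's entry.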
  allow_append_ = true;
  ICHECK_EQ(curr_stmt_.access.size(), 0U);
  curr_stmt_.stmt = op;
  StmtExprVisitor::VisitStmt_(op);
  // push to the scope
  if (curr_stmt_.access.size() != 0) {
    scope_.back().push_back(curr_stmt_);
    curr_stmt_.access.clear();
  }
  allow_append_ = false;
}

void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
  if (op->attr_key == attr::double_buffer_write) {
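    // Flag write accesses to the double-buffered variable so the summary can
    // distinguish them from ordinary writes.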
    ICHECK(double_buffer_write_ == nullptr);
    double_buffer_write_ = op->node.as<VarNode>();
    scope_.push_back(std::vector<StmtEntry>());
    StmtExprVisitor::VisitStmt_(op);
    StmtEntry s;
    s.stmt = op;
    s.access = Summarize(std::move(scope_.back()), nullptr);
    scope_.pop_back();
    if (!s.access.empty()) {
      for (AccessEntry& e : s.access) {
        if (e.type == kWrite && e.buffer.get() == double_buffer_write_) {
          e.double_buffer_write = true;
        }
      }
      scope_.back().emplace_back(std::move(s));
    }
    double_buffer_write_ = nullptr;
  } else if (op->attr_key == attr::coproc_scope) {
    IterVar iv = Downcast<IterVar>(op->node);
    env_threads_.push_back(iv);
    StmtExprVisitor::VisitStmt_(op);
    env_threads_.pop_back();
  } else if (op->attr_key == attr::thread_extent) {
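    // thread_extent binds an environment thread (e.g. threadIdx/blockIdx); the first
    // binding encountered marks entry into the device environment.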
    IterVar iv = Downcast<IterVar>(op->node);
    env_threads_.push_back(iv);
    if (!in_device_env_) {
      in_device_env_ = true;
      scope_.push_back(std::vector<StmtEntry>());
      StmtExprVisitor::VisitStmt_(op);
      // no need to take the result as the thread barrier automatically syncs.
      Summarize(std::move(scope_.back()), nullptr);
      in_device_env_ = false;
      scope_.pop_back();
    } else {
      StmtExprVisitor::VisitStmt_(op);
    }
    env_threads_.pop_back();
  } else if (op->attr_key == attr::hand_threaded) {
    // skip this pass on blocks that were hand_threaded
    // this avoids control flow and read/write conflicts
    // between hand-threaded kernels and automatic threading
  } else {
    StmtExprVisitor::VisitStmt_(op);
  }
}

void StorageAccessVisitor::VisitStmt_(const ForNode* op) {
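  // Summarize the accesses in the loop body, then relax each touched region over the
  // loop variable's full range so it covers every iteration.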
  scope_.push_back(std::vector<StmtEntry>());
  StmtExprVisitor::VisitStmt_(op);
  StmtEntry s;
  s.stmt = op;
  s.access = Summarize(std::move(scope_.back()), op);
  scope_.pop_back();
  if (s.access.size() != 0) {
    // relax the touched set to contain all ranges in the loop.
    std::unordered_map<const VarNode*, arith::IntSet> relax_map;
    relax_map[op->loop_var.get()] =
        arith::IntSet::FromRange(Range::FromMinExtent(op->min, op->extent));
    for (AccessEntry& e : s.access) {
      if (e.buffer.defined()) {
        ICHECK(e.touched.size());
        Array<arith::IntSet> new_touched;
        for (const auto& touched : e.touched) {
          new_touched.push_back(arith::EvalSet(touched, relax_map));
        }
        e.touched = std::move(new_touched);
      }
    }
  }
  if (!s.access.empty()) {
    scope_.back().emplace_back(std::move(s));
  }
}

void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) {
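  // condition_counter_ records that the body executes under a condition; the then and
  // else branches are summarized separately and merged into a single entry.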
  ++condition_counter_;
  this->VisitExpr(op->condition);
  scope_.push_back(std::vector<StmtEntry>());
  this->VisitStmt(op->then_case);
  StmtEntry s;
  s.stmt = op;
  s.access = Summarize(std::move(scope_.back()), nullptr);
  scope_.pop_back();
  if (op->else_case) {
    scope_.push_back(std::vector<StmtEntry>());
    this->VisitStmt(op->else_case.value());
    auto v = Summarize(std::move(scope_.back()), nullptr);
    scope_.pop_back();
    s.access.insert(s.access.end(), v.begin(), v.end());
  }
  scope_.back().emplace_back(std::move(s));
  --condition_counter_;
}

void StorageAccessVisitor::VisitStmt_(const WhileNode* op) {
  ++condition_counter_;
  this->VisitExpr(op->condition);
  scope_.push_back(std::vector<StmtEntry>());
  this->VisitStmt(op->body);
  StmtEntry s;
  s.stmt = op;
  s.access = Summarize(std::move(scope_.back()), nullptr);
  scope_.pop_back();
  scope_.back().emplace_back(std::move(s));
  --condition_counter_;
}

void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
  if (op->op.same_as(builtin::address_of())) {
    const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
    ICHECK(load != nullptr) << "address_of expects a BufferLoad argument";
    StmtExprVisitor::VisitExpr_(load);
  } else if (op->op.same_as(builtin::tvm_access_ptr())) {
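    // tvm_access_ptr(dtype, data, offset, extent, rw_mask): bit 1 of rw_mask marks a
    // read of the region, bit 2 marks a write; both may be set.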
    ICHECK_EQ(op->args.size(), 5U);
    DataType dtype = op->args[0].dtype();
    const VarNode* buffer = op->args[1].as<VarNode>();
    PrimExpr offset = op->args[2];
    PrimExpr extent = op->args[3];
    const IntImmNode* flag = op->args[4].as<IntImmNode>();
    StorageScope scope = GetScope(GetRef<Var>(buffer));
    // The buffer scope.
    if (Enabled(buffer, scope)) {
      ICHECK(allow_append_);
      AccessEntry e;
      e.threads = env_threads();
      e.dtype = dtype;
      e.buffer = Downcast<Var>(op->args[1]);
      e.touched = {arith::IntSet::FromRange(Range::FromMinExtent(offset, extent))};
      e.scope = scope;
      if (flag->value & 1) {
        e.type = kRead;
        curr_stmt_.access.emplace_back(e);
      }
      if (flag->value & 2) {
        e.type = kWrite;
        curr_stmt_.access.emplace_back(e);
      }
    }
    StmtExprVisitor::VisitExpr_(op);
  } else if (op->op.same_as(builtin::tvm_storage_sync())) {
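    // Skip "warp" sync; record an explicit kSync access entry for other scopes.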
    ICHECK(allow_append_);
    const std::string& s = op->args[0].as<StringImmNode>()->value;
    if (s != "warp") {
      StorageScope scope = StorageScope::Create(s);
      AccessEntry e;
      e.threads = env_threads();
      e.type = kSync;
      e.scope = scope;
      curr_stmt_.access.emplace_back(std::move(e));
    }
  } else {
    StmtExprVisitor::VisitExpr_(op);
  }
}

StorageScope StorageAccessVisitor::GetScope(Var buffer_var) const {
  if (buffer_var->type_annotation.as<PointerTypeNode>()) {
    return StorageScope::Create(GetPtrStorageScope(buffer_var));
  }
  return StorageScope();  // global by default
}

}  // namespace tir
}  // namespace tvm