1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief Utility to make loop nest. |
22 | * \file op_utils.cc |
23 | */ |
24 | #include "op_utils.h" |
25 | |
26 | #include <tvm/te/operation.h> |
27 | #include <tvm/tir/expr.h> |
28 | #include <tvm/tir/stmt_functor.h> |
29 | |
30 | #include <string> |
31 | |
32 | #include "../../runtime/thread_storage_scope.h" |
33 | #include "../schedule/message_passing.h" |
34 | |
35 | namespace tvm { |
36 | namespace te { |
37 | |
38 | using namespace arith; |
39 | using namespace tir; |
40 | |
/*!
 * \brief Materialize the loop nest for the leaf iter vars of a schedule stage.
 *
 * For each leaf IterVar of `stage` from position `begin_iter_pos` onward, emit
 * one level of statements into the returned nest:
 *  - a For loop (serial/unrolled/vectorized/parallel per iter_var_attrs) plus
 *    optional pragma/prefetch AttrStmts, when the IterVar is not bound to a
 *    thread tag;
 *  - an AttrStmt (virtual_thread / pipeline_exec_scope / thread_extent), when
 *    it is bound to a thread.
 * Every emitted statement wraps a no-op body (Evaluate(0)).
 * Also fills *p_value_map with the PrimExpr each IterVar evaluates to inside
 * the nest, then runs PassUpIndex so root iter vars get values as well.
 *
 * \param stage The schedule stage whose leaf iter vars are materialized.
 * \param dom_map Domain (Range) of each IterVar.
 * \param begin_iter_pos Index of the first leaf iter var to process.
 * \param new_loop_var If true, create fresh loop vars (name + ".init") instead
 *        of reusing the IterVar's own var; also suppresses the loop_scope
 *        annotation at the end of each level.
 * \param skip_iter Iter vars to skip; each maps to its own var in value_map.
 * \param p_value_map Output map IterVar -> value expression inside the nest.
 * \param debug_keep_trivial_loop If true, emit a loop even when extent == 1
 *        instead of collapsing it to a LetStmt binding.
 * \return nest where nest[i + 1] holds the statements for leaf iter var i.
 */
std::vector<std::vector<Stmt>> MakeLoopNest(const Stage& stage,
                                            const std::unordered_map<IterVar, Range>& dom_map,
                                            size_t begin_iter_pos, bool new_loop_var,
                                            const std::unordered_set<IterVar>& skip_iter,
                                            std::unordered_map<IterVar, PrimExpr>* p_value_map,
                                            bool debug_keep_trivial_loop) {
  auto leaf_iter_vars = stage->leaf_iter_vars;
  // Placeholder body shared by every statement built below.
  Stmt no_op = Evaluate(0);
  // create the loop nest; one slot per leaf iter var, offset by one.
  std::vector<std::vector<Stmt>> nest;
  nest.resize(leaf_iter_vars.size() + 1);
  std::unordered_map<IterVar, PrimExpr>& value_map = *p_value_map;

  for (size_t i = begin_iter_pos; i < leaf_iter_vars.size(); ++i) {
    auto iv = leaf_iter_vars[i];
    if (skip_iter.count(iv) || iv->iter_type == kOpaque) {
      // skip this iteration: the iter var simply stands for its own variable.
      value_map[iv] = iv->var;
      continue;
    }
    // Bind iv could be another thread.
    IterVar bind_iv = iv;
    if (stage->iter_var_attrs.count(iv)) {
      IterVar bind_thread = stage->iter_var_attrs[iv]->bind_thread;
      if (bind_thread.defined()) bind_iv = bind_thread;
    }

    Range dom = dom_map.at(iv);

    ICHECK(iv->var.dtype() == dom->min.dtype() && iv->var.dtype() == dom->extent.dtype())
        << "iter_var type " << iv->var.dtype() << " and domain types (min:" << dom->min.dtype()
        << ", extent:" << dom->extent.dtype() << ") should all be the same";

    // This is a hack to ensure that the replacing expression has the same
    // dtype as the replacing expression. This happens when a thread/block
    // itervar is bound to another itervar. Because the thread/block itervar
    // has no way to know its correct dtype before it is bound, it defaults to
    // int32. Then the itervar it is bound to may have a different dtype. The
    // thread/block dtype really should be promoted to dtype of what it is
    // bound to (in `bind`) but that would require inplace modification of the
    // itervar.
    // XXX: we will get integer overflow if the bound itervar is greater than int32::max.
    auto promote_to_iv_dtype = [type = iv->var.dtype()](PrimExpr e) {
      return type != e.dtype() ? cast(type, e) : e;
    };

    // initialize the offset and loop_level
    Var var = bind_iv->var;

    // Mark the iter var in the IR, to remember the point
    if (bind_iv->thread_tag.length() == 0) {
      // Only generate new loop if we're not bound to a thread.
      if (new_loop_var) {
        var = Var(iv->var->name_hint + ".init", bind_iv->var.dtype());
      }

      // Pick the loop kind from the stage's iter var attributes (default serial).
      ForKind kind = ForKind::kSerial;
      IterVarAttr it_attr;
      if (stage->iter_var_attrs.count(iv)) {
        it_attr = stage->iter_var_attrs[iv];
      }
      if (it_attr.defined()) {
        switch (it_attr->iter_type) {
          case kUnrolled:
            kind = ForKind::kUnrolled;
            break;
          case kVectorized:
            kind = ForKind::kVectorized;
            break;
          case kParallelized:
            kind = ForKind::kParallel;
            break;
          case kDataPar:
            break;
          case kTensorized:
            break;
          default:
            LOG(FATAL) << "Unknown iter type" << it_attr->iter_type << " in the iter_var_attrs";
        }
        ICHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size());
        // Emit one pragma AttrStmt per key; a missing value defaults to int32 1.
        for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) {
          const std::string& pkey = it_attr->pragma_keys[k].as<StringImmNode>()->value;
          PrimExpr pvalue = it_attr->pragma_values[k];
          if (!pvalue.defined()) {
            pvalue = make_const(DataType::Int(32), 1);
          }
          nest[i + 1].emplace_back(
              AttrStmt(iv, tir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
        }
      }
      if (!debug_keep_trivial_loop && is_one(dom->extent)) {
        // Trivial extent-1 loop collapses to a single binding var = dom->min.
        nest[i + 1].emplace_back(LetStmt(var, dom->min, no_op));
        value_map[iv] = dom->min;
      } else if (is_zero(dom->min)) {
        // Zero-based domain: loop var ranges over [0, extent) directly.
        nest[i + 1].emplace_back(For(var, 0, dom->extent, kind, no_op));
        value_map[iv] = promote_to_iv_dtype(var);
      } else {
        // Non-zero min: loop a fresh zero-based index, then bind var = min + idx.
        Var idx(bind_iv->var->name_hint + ".idx", iv->var.dtype());
        nest[i + 1].emplace_back(For(idx, 0, dom->extent, kind, no_op));
        PrimExpr new_value = dom->min + idx;
        value_map[iv] = new_value;
        nest[i + 1].emplace_back(LetStmt(var, new_value, no_op));
      }
      if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
        ICHECK(!is_one(dom->extent)) << "Cannot prefetch on trivial loop with extent=1";
        ICHECK_EQ(it_attr->prefetch_data.size(), it_attr->prefetch_offset.size());
        for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
          nest[i + 1].emplace_back(AttrStmt(it_attr->prefetch_data[j], tir::attr::prefetch_scope,
                                            it_attr->prefetch_offset[j], no_op));
        }
      }
    } else if (bind_iv->thread_tag == "vthread" || bind_iv->thread_tag == "cthread") {
      // virtual thread
      // Always restrict threaded IterVar to starts from 0.
      ICHECK(is_zero(dom->min));
      ICHECK(is_positive_const(dom->extent));
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::virtual_thread, dom->extent, no_op));
      value_map[iv] = promote_to_iv_dtype(var);
    } else if (bind_iv->thread_tag == "pipeline") {
      // pipeline marker.
      ICHECK(is_zero(dom->min));
      ICHECK(is_one(dom->extent));
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(
          AttrStmt(bind_iv, tir::attr::pipeline_exec_scope, dom->extent, no_op));
      value_map[iv] = dom->min;
    } else {
      // Always restrict threaded IterVar to starts from 0.
      ICHECK(is_zero(dom->min)) << "Itervar " << iv << " must start at zero, but it starts at "
                                << dom->min;
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, dom->extent, no_op));
      if (!debug_keep_trivial_loop && is_one(dom->extent)) {
        value_map[iv] = dom->min;
      } else if (stage->scope == "") {
        value_map[iv] = promote_to_iv_dtype(var);
      } else {
        // Compare thread rank against the storage rank of the stage's scope to
        // decide whether this thread index participates in addressing.
        runtime::ThreadScope ts = runtime::ThreadScope::Create(bind_iv->thread_tag);
        runtime::StorageScope ss = runtime::StorageScope::Create(stage->scope);
        if (static_cast<int>(ss.rank) <= ts.rank) {
          value_map[iv] = promote_to_iv_dtype(var);
        } else if (stage->scope == "warp" && ts.rank == 1) {
          // To determine whether a thread index is inside or outside a warp, we need
          // to know the thread extent. We leave a warning for now.
          if (ts.dim_index == 0) {
            value_map[iv] = promote_to_iv_dtype(var);
          } else {
            LOG(WARNING)
                << "WARNING: threadIdx.y or threadIdx.z accessing warp-scope memory detected. "
                << "TVM assumes only threadIdx.x indicates threads inside a warp, "
                << "while threadIdx.y and threadIdx.z indicates different warps.";
            value_map[iv] = dom->min;
          }
        } else {
          value_map[iv] = dom->min;
        }
      }
    }
    // annotate the extent of the IterVar
    if (!new_loop_var) {
      nest[i + 1].emplace_back(AttrStmt(iv, tir::attr::loop_scope, iv->var, no_op));
    }
  }
  // message passing to get offset of root iter vars.
  te::PassUpIndex(stage, dom_map, &value_map);
  return nest;
}
209 | |
210 | std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates) { |
211 | Stmt no_op = Evaluate(0); |
212 | std::vector<Stmt> nest; |
213 | for (const PrimExpr& cond : predicates) { |
214 | nest.emplace_back(IfThenElse(cond, no_op)); |
215 | } |
216 | return nest; |
217 | } |
218 | |
219 | // replacer to replace tensors |
220 | class TensorReplacer : public tir::StmtExprMutator { |
221 | public: |
222 | explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap) : vmap_(vmap) {} |
223 | |
224 | PrimExpr VisitExpr_(const tir::ProducerLoadNode* op) final { |
225 | PrimExpr expr = StmtExprMutator::VisitExpr_(op); |
226 | op = expr.as<tir::ProducerLoadNode>(); |
227 | ICHECK(op != nullptr); |
228 | |
229 | Tensor t = Downcast<Tensor>(op->producer); |
230 | auto it = vmap_.find(t); |
231 | if (it != vmap_.end()) { |
232 | found = true; |
233 | return tir::ProducerLoad(it->second, op->indices); |
234 | } else { |
235 | return expr; |
236 | } |
237 | } |
238 | |
239 | // whether it is found. |
240 | bool found{false}; |
241 | |
242 | private: |
243 | const std::unordered_map<Tensor, Tensor>& vmap_; |
244 | }; |
245 | |
246 | Stmt ReplaceTensor(Stmt stmt, const std::unordered_map<Tensor, Tensor>& replace) { |
247 | TensorReplacer repl(replace); |
248 | Stmt ret = repl(stmt); |
249 | return repl.found ? ret : stmt; |
250 | } |
251 | PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map<Tensor, Tensor>& replace) { |
252 | TensorReplacer repl(replace); |
253 | PrimExpr ret = repl(expr); |
254 | return repl.found ? ret : expr; |
255 | } |
256 | |
257 | Stmt Substitute(Stmt s, const std::unordered_map<IterVar, PrimExpr>& value_map) { |
258 | std::unordered_map<const VarNode*, PrimExpr> init; |
259 | for (const auto& kv : value_map) { |
260 | init[kv.first->var.get()] = kv.second; |
261 | } |
262 | return tir::Substitute(s, init); |
263 | } |
264 | |
265 | PrimExpr Substitute(PrimExpr s, const std::unordered_map<IterVar, PrimExpr>& value_map) { |
266 | std::unordered_map<const VarNode*, PrimExpr> init; |
267 | for (const auto& kv : value_map) { |
268 | init[kv.first->var.get()] = kv.second; |
269 | } |
270 | return tir::Substitute(s, init); |
271 | } |
272 | |
273 | IterVarType ForKindToIterVarType(tir::ForKind kind) { |
274 | switch (kind) { |
275 | case ForKind::kSerial: |
276 | return kDataPar; |
277 | case ForKind::kParallel: |
278 | return kParallelized; |
279 | case ForKind::kVectorized: |
280 | return kVectorized; |
281 | case ForKind::kUnrolled: |
282 | return kUnrolled; |
283 | default: |
284 | return kDataPar; |
285 | } |
286 | } |
287 | |
288 | tir::ForKind IterVarTypeToForKind(IterVarType iter_type) { |
289 | switch (iter_type) { |
290 | case kDataPar: |
291 | return ForKind::kSerial; |
292 | case kParallelized: |
293 | return ForKind::kParallel; |
294 | case kVectorized: |
295 | return ForKind::kVectorized; |
296 | case kUnrolled: |
297 | return ForKind::kUnrolled; |
298 | default: |
299 | return ForKind::kSerial; |
300 | } |
301 | } |
302 | |
303 | } // namespace te |
304 | } // namespace tvm |
305 | |