/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Utilities to make loop nests.
 * \file op_utils.cc
 */
#include "op_utils.h"

#include <tvm/te/operation.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

#include <string>

#include "../../runtime/thread_storage_scope.h"
#include "../schedule/message_passing.h"

namespace tvm {
namespace te {

using namespace arith;
using namespace tir;

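// Build the loop nest for the leaf iter vars of `stage`, starting at
// `begin_iter_pos`.  The result groups the generated For/Let/AttrStmt
// statements per leaf iter var, and `p_value_map` is filled with the
// expression each iter var evaluates to inside the nest.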
std::vector<std::vector<Stmt>> MakeLoopNest(const Stage& stage,
                                            const std::unordered_map<IterVar, Range>& dom_map,
                                            size_t begin_iter_pos, bool new_loop_var,
                                            const std::unordered_set<IterVar>& skip_iter,
                                            std::unordered_map<IterVar, PrimExpr>* p_value_map,
                                            bool debug_keep_trivial_loop) {
  auto leaf_iter_vars = stage->leaf_iter_vars;
  Stmt no_op = Evaluate(0);
  // create the loop nest
  std::vector<std::vector<Stmt>> nest;
  nest.resize(leaf_iter_vars.size() + 1);
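  // nest[i + 1] collects the statements generated for the i-th leaf iter var;
  // slot 0 is left empty here (it is the position outside all loops).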
  std::unordered_map<IterVar, PrimExpr>& value_map = *p_value_map;

  for (size_t i = begin_iter_pos; i < leaf_iter_vars.size(); ++i) {
    auto iv = leaf_iter_vars[i];
    if (skip_iter.count(iv) || iv->iter_type == kOpaque) {
      // skip this iteration.
      value_map[iv] = iv->var;
      continue;
    }
    // The iter var may be bound to a thread iter var.
    IterVar bind_iv = iv;
    if (stage->iter_var_attrs.count(iv)) {
      IterVar bind_thread = stage->iter_var_attrs[iv]->bind_thread;
      if (bind_thread.defined()) bind_iv = bind_thread;
    }

    Range dom = dom_map.at(iv);

    ICHECK(iv->var.dtype() == dom->min.dtype() && iv->var.dtype() == dom->extent.dtype())
        << "iter_var type " << iv->var.dtype() << " and domain types (min:" << dom->min.dtype()
        << ", extent:" << dom->extent.dtype() << ") should all be the same";

    // This is a hack to ensure that the replacing expression has the same
    // dtype as the expression it replaces. The situation arises when a
    // thread/block itervar is bound to another itervar. Because the
    // thread/block itervar has no way to know its correct dtype before it is
    // bound, it defaults to int32; the itervar it is bound to may then have a
    // different dtype. The thread/block dtype really should be promoted to the
    // dtype of what it is bound to (in `bind`), but that would require
    // in-place modification of the itervar.
    // XXX: we will get integer overflow if the bound itervar is greater than int32::max.
    auto promote_to_iv_dtype = [type = iv->var.dtype()](PrimExpr e) {
      return type != e.dtype() ? cast(type, e) : e;
    };

    // The variable that carries this iter var's value in the generated IR.
    Var var = bind_iv->var;

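    // How the iter var is realized depends on its thread binding:
    //  - no thread tag: emit an explicit For loop (or a Let for trivial extents),
    //  - "vthread"/"cthread": annotate a virtual thread,
    //  - "pipeline": annotate a pipeline execution scope,
    //  - anything else: annotate a hardware thread extent.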
    if (bind_iv->thread_tag.length() == 0) {
      // Not bound to a thread: generate an explicit loop for this iter var.
      if (new_loop_var) {
        var = Var(iv->var->name_hint + ".init", bind_iv->var.dtype());
      }

      ForKind kind = ForKind::kSerial;
      IterVarAttr it_attr;
      if (stage->iter_var_attrs.count(iv)) {
        it_attr = stage->iter_var_attrs[iv];
      }
      if (it_attr.defined()) {
        switch (it_attr->iter_type) {
          case kUnrolled:
            kind = ForKind::kUnrolled;
            break;
          case kVectorized:
            kind = ForKind::kVectorized;
            break;
          case kParallelized:
            kind = ForKind::kParallel;
            break;
          case kDataPar:
            break;
          case kTensorized:
            break;
          default:
            LOG(FATAL) << "Unknown iter type " << it_attr->iter_type << " in the iter_var_attrs";
        }
        ICHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size());
        for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) {
          const std::string& pkey = it_attr->pragma_keys[k].as<StringImmNode>()->value;
          PrimExpr pvalue = it_attr->pragma_values[k];
          if (!pvalue.defined()) {
            pvalue = make_const(DataType::Int(32), 1);
          }
          nest[i + 1].emplace_back(
              AttrStmt(iv, tir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
        }
      }
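      // Generate the loop itself.  A trivial extent-1 loop is folded into a
      // LetStmt; a loop whose domain starts at zero iterates `var` directly;
      // otherwise a fresh index is iterated and `var` is let-bound to
      // `dom->min + idx`.  Conceptually, for dom = [4, 4 + 8):
      //   for (idx, 0, 8) { let var = 4 + idx; body }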
      if (!debug_keep_trivial_loop && is_one(dom->extent)) {
        nest[i + 1].emplace_back(LetStmt(var, dom->min, no_op));
        value_map[iv] = dom->min;
      } else if (is_zero(dom->min)) {
        nest[i + 1].emplace_back(For(var, 0, dom->extent, kind, no_op));
        value_map[iv] = promote_to_iv_dtype(var);
      } else {
        Var idx(bind_iv->var->name_hint + ".idx", iv->var.dtype());
        nest[i + 1].emplace_back(For(idx, 0, dom->extent, kind, no_op));
        PrimExpr new_value = dom->min + idx;
        value_map[iv] = new_value;
        nest[i + 1].emplace_back(LetStmt(var, new_value, no_op));
      }
      if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
        ICHECK(!is_one(dom->extent)) << "Cannot prefetch on trivial loop with extent=1";
        ICHECK_EQ(it_attr->prefetch_data.size(), it_attr->prefetch_offset.size());
        for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
          nest[i + 1].emplace_back(AttrStmt(it_attr->prefetch_data[j], tir::attr::prefetch_scope,
                                            it_attr->prefetch_offset[j], no_op));
        }
      }
    } else if (bind_iv->thread_tag == "vthread" || bind_iv->thread_tag == "cthread") {
      // virtual thread
      // A threaded IterVar must always start from 0.
      ICHECK(is_zero(dom->min));
      ICHECK(is_positive_const(dom->extent));
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::virtual_thread, dom->extent, no_op));
      value_map[iv] = promote_to_iv_dtype(var);
    } else if (bind_iv->thread_tag == "pipeline") {
      // pipeline marker.
      ICHECK(is_zero(dom->min));
      ICHECK(is_one(dom->extent));
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(
          AttrStmt(bind_iv, tir::attr::pipeline_exec_scope, dom->extent, no_op));
      value_map[iv] = dom->min;
    } else {
      // A threaded IterVar must always start from 0.
      ICHECK(is_zero(dom->min)) << "Itervar " << iv << " must start at zero, but it starts at "
                                << dom->min;
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, dom->extent, no_op));
      if (!debug_keep_trivial_loop && is_one(dom->extent)) {
        value_map[iv] = dom->min;
      } else if (stage->scope == "") {
        value_map[iv] = promote_to_iv_dtype(var);
      } else {
        runtime::ThreadScope ts = runtime::ThreadScope::Create(bind_iv->thread_tag);
        runtime::StorageScope ss = runtime::StorageScope::Create(stage->scope);
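        // Inside a scoped stage, only thread indices whose rank is at least the
        // storage rank index into the buffer; a coarser thread index (e.g.
        // blockIdx for shared memory) is the same for the whole buffer and is
        // pinned to dom->min.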
        if (static_cast<int>(ss.rank) <= ts.rank) {
          value_map[iv] = promote_to_iv_dtype(var);
        } else if (stage->scope == "warp" && ts.rank == 1) {
          // To determine whether a thread index is inside or outside a warp, we need
          // to know the thread extent. We leave a warning for now.
          if (ts.dim_index == 0) {
            value_map[iv] = promote_to_iv_dtype(var);
          } else {
            LOG(WARNING) << "threadIdx.y or threadIdx.z accessing warp-scope memory detected. "
                         << "TVM assumes only threadIdx.x indicates threads inside a warp, "
                         << "while threadIdx.y and threadIdx.z indicate different warps.";
            value_map[iv] = dom->min;
          }
        } else {
          value_map[iv] = dom->min;
        }
      }
    }
    // mark the loop scope of the IterVar
    if (!new_loop_var) {
      nest[i + 1].emplace_back(AttrStmt(iv, tir::attr::loop_scope, iv->var, no_op));
    }
  }
  // message passing to get offsets of the root iter vars.
  te::PassUpIndex(stage, dom_map, &value_map);
  return nest;
}

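// Wrap each predicate in an IfThenElse with a no-op body; the real body is
// substituted in when the nest is merged with the body statement.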
std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates) {
  Stmt no_op = Evaluate(0);
  std::vector<Stmt> nest;
  for (const PrimExpr& cond : predicates) {
    nest.emplace_back(IfThenElse(cond, no_op));
  }
  return nest;
}

// Mutator that rewrites tensor references according to the given replacement map.
class TensorReplacer : public tir::StmtExprMutator {
 public:
  explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap) : vmap_(vmap) {}

  PrimExpr VisitExpr_(const tir::ProducerLoadNode* op) final {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<tir::ProducerLoadNode>();
    ICHECK(op != nullptr);

    Tensor t = Downcast<Tensor>(op->producer);
    auto it = vmap_.find(t);
    if (it != vmap_.end()) {
      found = true;
      return tir::ProducerLoad(it->second, op->indices);
    } else {
      return expr;
    }
  }

  // Whether any replacement was performed.
  bool found{false};

 private:
  const std::unordered_map<Tensor, Tensor>& vmap_;
};

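// Replace tensor references in a statement or expression; if no reference
// matched the replacement map, the original object is returned unchanged.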
Stmt ReplaceTensor(Stmt stmt, const std::unordered_map<Tensor, Tensor>& replace) {
  TensorReplacer repl(replace);
  Stmt ret = repl(stmt);
  return repl.found ? ret : stmt;
}

PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map<Tensor, Tensor>& replace) {
  TensorReplacer repl(replace);
  PrimExpr ret = repl(expr);
  return repl.found ? ret : expr;
}

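// Substitute iter vars with their computed values by remapping the
// IterVar -> PrimExpr map onto the underlying Vars for tir::Substitute.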
Stmt Substitute(Stmt s, const std::unordered_map<IterVar, PrimExpr>& value_map) {
  std::unordered_map<const VarNode*, PrimExpr> init;
  for (const auto& kv : value_map) {
    init[kv.first->var.get()] = kv.second;
  }
  return tir::Substitute(s, init);
}

PrimExpr Substitute(PrimExpr s, const std::unordered_map<IterVar, PrimExpr>& value_map) {
  std::unordered_map<const VarNode*, PrimExpr> init;
  for (const auto& kv : value_map) {
    init[kv.first->var.get()] = kv.second;
  }
  return tir::Substitute(s, init);
}

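// Conversions between tir::ForKind and te::IterVarType.  Kinds and types
// without a direct counterpart fall back to serial / data-parallel iteration.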
IterVarType ForKindToIterVarType(tir::ForKind kind) {
  switch (kind) {
    case ForKind::kSerial:
      return kDataPar;
    case ForKind::kParallel:
      return kParallelized;
    case ForKind::kVectorized:
      return kVectorized;
    case ForKind::kUnrolled:
      return kUnrolled;
    default:
      return kDataPar;
  }
}

tir::ForKind IterVarTypeToForKind(IterVarType iter_type) {
  switch (iter_type) {
    case kDataPar:
      return ForKind::kSerial;
    case kParallelized:
      return ForKind::kParallel;
    case kVectorized:
      return ForKind::kVectorized;
    case kUnrolled:
      return ForKind::kUnrolled;
    default:
      return ForKind::kSerial;
  }
}

}  // namespace te
}  // namespace tvm