1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file ir_utils.cc |
22 | * \brief Helper functions to construct and compose IR nodes. |
23 | */ |
24 | #include "ir_utils.h" |
25 | |
26 | #include <tvm/arith/analyzer.h> |
27 | #include <tvm/arith/int_solver.h> |
28 | #include <tvm/tir/stmt_functor.h> |
29 | |
30 | #include <unordered_map> |
31 | #include <unordered_set> |
32 | #include <utility> |
33 | |
34 | namespace tvm { |
35 | namespace tir { |
36 | |
37 | Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) { |
38 | // use reverse iteration |
39 | for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { |
40 | Stmt s = *ri; |
41 | if (const auto* for_ = s.as<ForNode>()) { |
42 | auto n = make_object<ForNode>(*for_); |
43 | ICHECK(is_no_op(n->body)); |
44 | n->body = body; |
45 | body = Stmt(n); |
46 | } else if (const auto* let = s.as<LetStmtNode>()) { |
47 | auto n = make_object<LetStmtNode>(*let); |
48 | ICHECK(is_no_op(n->body)); |
49 | n->body = body; |
50 | body = Stmt(n); |
51 | } else if (const auto* attr = s.as<AttrStmtNode>()) { |
52 | auto n = make_object<AttrStmtNode>(*attr); |
53 | ICHECK(is_no_op(n->body)); |
54 | n->body = body; |
55 | body = Stmt(n); |
56 | } else if (const auto* ite = s.as<IfThenElseNode>()) { |
57 | auto n = make_object<IfThenElseNode>(*ite); |
58 | ICHECK(is_no_op(n->then_case)); |
59 | ICHECK(!n->else_case); |
60 | n->then_case = body; |
61 | body = Stmt(n); |
62 | } else if (const auto* seq = s.as<SeqStmtNode>()) { |
63 | auto n = make_object<SeqStmtNode>(*seq); |
64 | ICHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1])); |
65 | n->seq.Set(n->size() - 1, body); |
66 | body = Stmt(n); |
67 | } else if (const auto* assert_ = s.as<AssertStmtNode>()) { |
68 | auto n = make_object<AssertStmtNode>(*assert_); |
69 | ICHECK(is_no_op(n->body)); |
70 | n->body = body; |
71 | body = Stmt(n); |
72 | } else if (const auto* alloc = s.as<AllocateNode>()) { |
73 | auto n = make_object<AllocateNode>(*alloc); |
74 | ICHECK(is_no_op(n->body)); |
75 | n->body = body; |
76 | body = Stmt(n); |
77 | } else { |
78 | LOG(FATAL) << "not supported nest type" ; |
79 | } |
80 | } |
81 | return body; |
82 | } |
83 | |
84 | Stmt MergeNest(const std::vector<std::vector<Stmt>>& nest, Stmt body) { |
85 | for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { |
86 | body = MergeNest(*ri, body); |
87 | } |
88 | return body; |
89 | } |
90 | |
/*!
 * \brief Mutator that rewrites a TIR statement/expression into SSA form.
 *
 * Every variable definition seen so far is recorded in `defined_`.  When a
 * variable is defined a second time (Let/LetStmt, For loop var, Allocate
 * buffer var), the inner definition is renamed to a fresh variable; uses
 * inside that scope resolve through the per-variable stack in `scope_`.
 * Buffers whose backing `data` var was renamed are remapped to new Buffer
 * objects via `buf_remap_`, scoped the same way (see ScopedRedefine).
 */
class IRConvertSSA final : public StmtExprMutator {
 public:
  // A variable use: substitute the innermost active redefinition, if any.
  PrimExpr VisitExpr_(const VarNode* op) final {
    if (scope_.count(op) && !scope_[op].empty()) {
      return scope_[op].back();
    } else {
      return GetRef<PrimExpr>(op);
    }
  }
  // Let binding: the first occurrence of a var only records the definition;
  // a repeated definition is renamed to a fresh variable for its scope.
  PrimExpr VisitExpr_(const LetNode* op) final {
    const Var& v = op->var;
    if (defined_.count(v.get())) {
      // The bound value is visited *before* pushing the redefine, so it
      // still refers to the outer definition of `v`.
      PrimExpr value = this->VisitExpr(op->value);
      ScopedRedefine redefine(this, v);
      PrimExpr body = this->VisitExpr(op->body);
      return Let(redefine.new_var, value, body);
    } else {
      defined_.insert(v.get());
      return StmtExprMutator::VisitExpr_(op);
    }
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  // Buffer accesses: after the usual index/value mutation, swap in the
  // remapped buffer if the buffer's data var was redefined in this scope.
  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto output = VisitBufferAccess(std::move(node));
    return std::move(output);
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto output = VisitBufferAccess(std::move(node));
    return std::move(output);
  }

  // Buffer declarations are rewritten in place to reference the remapped
  // buffer object when one exists.
  Stmt VisitStmt_(const DeclBufferNode* op) final {
    DeclBuffer decl = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
    Buffer new_buffer = GetRemappedBuffer(decl->buffer);
    if (!new_buffer.same_as(decl->buffer)) {
      decl.CopyOnWrite()->buffer = std::move(new_buffer);
    }
    return std::move(decl);
  }

  // Shared helper for BufferLoad/BufferStore: replace node->buffer with its
  // remapped counterpart, copying the node only when a change is needed.
  template <typename Node>
  Node VisitBufferAccess(Node node) {
    Buffer new_buf = GetRemappedBuffer(node->buffer);
    if (!new_buf.same_as(node->buffer)) {
      auto writer = node.CopyOnWrite();
      writer->buffer = new_buf;
    }

    return node;
  }

  // Return the buffer to use for `buf` in the current scope, creating (and
  // caching) a new Buffer object if buf->data was redefined.
  Buffer GetRemappedBuffer(Buffer buf) {
    // Determine the buffer var that should be in the updated buffer,
    // given the current scope. If no redefines are present, then the
    // buffer var is unchanged.
    Var new_buffer_var = buf->data;
    auto var_it = scope_.find(buf->data.get());
    if (var_it != scope_.end() && !var_it->second.empty()) {
      new_buffer_var = var_it->second.back();
    }

    // If no mapping is required, return the original buffer.
    if (new_buffer_var.same_as(buf->data)) {
      return buf;
    }

    // If the current scope already has a mapping of this buffer, use
    // the mapped buffer.
    auto key = buf.get();
    std::vector<Buffer>& buffers = buf_remap_[key];
    if (buffers.size() && buffers.back()->data.same_as(new_buffer_var)) {
      return buffers.back();
    }

    // Otherwise, make and return a new buffer object that uses the
    // new buffer, pushing it onto the scoped stack of existing
    // buffers. This will be popped when the new_buffer_var
    // redefinition is popped.
    Buffer new_buf(new_buffer_var, buf->dtype, buf->shape, buf->strides, buf->elem_offset,
                   buf->name, buf->data_alignment, buf->offset_factor, buf->buffer_type,
                   buf->axis_separators, buf->span);
    buffers.push_back(new_buf);
    return new_buf;
  }

  // Statement-level let: same first-definition/redefinition split as LetNode.
  Stmt VisitStmt_(const LetStmtNode* op) final {
    const Var& v = op->var;
    if (defined_.count(v.get())) {
      PrimExpr value = this->VisitExpr(op->value);
      ScopedRedefine redefine(this, v);
      Stmt body = this->VisitStmt(op->body);
      return LetStmt(redefine.new_var, value, body);
    } else {
      defined_.insert(v.get());
      return StmtExprMutator::VisitStmt_(op);
    }
  }
  // Loop variable: a repeated loop var is renamed for the loop's scope.
  Stmt VisitStmt_(const ForNode* op) final {
    const Var& v = op->loop_var;
    if (defined_.count(v.get())) {
      ScopedRedefine redefine(this, v);
      Stmt stmt = StmtExprMutator::VisitStmt_(op);
      op = stmt.as<ForNode>();
      return For(redefine.new_var, op->min, op->extent, op->kind, op->body, op->thread_binding,
                 op->annotations);
    } else {
      defined_.insert(v.get());
      return StmtExprMutator::VisitStmt_(op);
    }
  }
  // Allocation buffer var: renamed when redefined, like a loop var.
  Stmt VisitStmt_(const AllocateNode* op) final {
    const Var& v = op->buffer_var;
    if (defined_.count(v.get())) {
      ScopedRedefine redefine(this, v);
      Stmt stmt = StmtExprMutator::VisitStmt_(op);
      op = stmt.as<AllocateNode>();
      // NOTE(review): only (dtype, extents, condition, body) are forwarded
      // here; any additional AllocateNode fields (e.g. annotations) fall back
      // to constructor defaults — confirm this is intended.
      return Allocate(redefine.new_var, op->dtype, op->extents, op->condition, op->body);
    } else {
      defined_.insert(v.get());
      return StmtExprMutator::VisitStmt_(op);
    }
  }
  // Attribute statements whose node is a Var (e.g. thread-extent attrs) must
  // also point at the renamed variable when one is active.
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (const VarNode* v = op->node.as<VarNode>()) {
      Stmt stmt = StmtExprMutator::VisitStmt_(op);
      op = stmt.as<AttrStmtNode>();
      if (scope_.count(v) && scope_[v].size() != 0) {
        return AttrStmt(scope_[v].back(), op->attr_key, op->value, op->body);
      } else {
        return stmt;
      }
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }

 private:
  // RAII helper: on construction pushes a fresh variable onto the scope
  // stack for `old_var`; on destruction pops it again and discards any
  // buffer remaps that were created for the fresh variable.
  struct ScopedRedefine {
    ScopedRedefine(IRConvertSSA* parent, Var old_var) : parent(parent), old_var(old_var) {
      // Preserve the type annotation when present; otherwise fall back to
      // the variable's dtype.
      if (old_var->type_annotation.defined()) {
        new_var = Var(old_var->name_hint, old_var->type_annotation);
      } else {
        new_var = Var(old_var->name_hint, old_var->dtype);
      }
      parent->scope_[old_var.get()].push_back(new_var);
    }

    ~ScopedRedefine() {
      parent->scope_[old_var.get()].pop_back();
      // Pop any buffer remaps whose data var is the variable going out of
      // scope, keeping buf_remap_ consistent with scope_.
      for (auto& kv : parent->buf_remap_) {
        std::vector<Buffer>& buffers = kv.second;
        if (buffers.size() && (buffers.back()->data.get() == new_var.get())) {
          buffers.pop_back();
        }
      }
    }

    IRConvertSSA* parent;
    Var old_var;
    Var new_var;
  };

  // Stack of active renames per original variable (innermost at the back).
  std::unordered_map<const VarNode*, std::vector<Var>> scope_;
  // All variables that have been defined at least once.
  std::unordered_set<const VarNode*> defined_;
  // Stack of remapped Buffer objects per original buffer node.
  std::unordered_map<const BufferNode*, std::vector<Buffer>> buf_remap_;
};
268 | |
269 | Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } |
270 | |
271 | String GetPtrStorageScope(Var buffer_var) { |
272 | const auto* ptr_type = buffer_var->type_annotation.as<PointerTypeNode>(); |
273 | ICHECK(ptr_type) << "The provided variable is not of pointer type" ; |
274 | return ptr_type->storage_scope; |
275 | } |
276 | |
277 | Array<PrimExpr> ConvertIndices(const MatchBufferRegion& match_buffer, |
278 | const Array<PrimExpr>& indices) { |
279 | const Buffer& target = match_buffer->buffer; |
280 | const BufferRegion& source = match_buffer->source; |
281 | ICHECK_EQ(indices.size(), target->shape.size()); |
282 | |
283 | arith::Analyzer analyzer; |
284 | Array<PrimExpr> result; |
285 | result.reserve(source->region.size()); |
286 | size_t offset = source->region.size() - indices.size(); |
287 | for (size_t i = 0; i < offset; ++i) { |
288 | const Range& range = source->region[i]; |
289 | ICHECK(analyzer.CanProve(range->extent == 1)); |
290 | result.push_back(range->min); |
291 | } |
292 | for (size_t i = 0; i < indices.size(); ++i) { |
293 | const Range& range = source->region[i + offset]; |
294 | const PrimExpr& index = indices[i]; |
295 | result.push_back(range->min + index); |
296 | } |
297 | return result; |
298 | } |
299 | |
300 | Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region) { |
301 | const Buffer& target = match_buffer->buffer; |
302 | const BufferRegion& source = match_buffer->source; |
303 | ICHECK_EQ(region.size(), target->shape.size()); |
304 | |
305 | arith::Analyzer analyzer; |
306 | Region result; |
307 | result.reserve(source->region.size()); |
308 | size_t offset = source->region.size() - region.size(); |
309 | for (size_t i = 0; i < offset; ++i) { |
310 | const Range& source_range = source->region[i]; |
311 | ICHECK(analyzer.CanProve(source_range->extent == 1)); |
312 | result.push_back(Range::FromMinExtent(source_range->min, 1)); |
313 | } |
314 | for (size_t i = 0; i < region.size(); ++i) { |
315 | const Range& source_range = source->region[i + offset]; |
316 | const Range& target_range = region[i]; |
317 | result.push_back( |
318 | Range::FromMinExtent(source_range->min + target_range->min, target_range->extent)); |
319 | } |
320 | return result; |
321 | } |
322 | |
323 | Bool IsFromLegacyTESchedule(PrimFunc f) { |
324 | Optional<Bool> from_legacy_te_schedule = f->GetAttr("from_legacy_te_schedule" , Bool(false)); |
325 | return from_legacy_te_schedule.value(); |
326 | } |
327 | |
Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
  // extract equations and related vars from condition expression.
  // currently only extract simple integral equations which could be solvable.
  arith::Analyzer analyzer;
  // For the false branch, solve against the negated condition.
  PrimExpr condition = is_true_branch_ ? condition_ : analyzer.Simplify(!condition_);
  Array<PrimExpr> equations;
  Array<Var> vars;
  // Recursively collect comparison sub-expressions reachable through
  // conjunctions (And) and builtin::likely() wrappers.
  std::function<void(const PrimExpr&)> fvisit = [&equations, &vars, &fvisit](const PrimExpr& e) {
    if (e->IsInstance<GENode>() || e->IsInstance<GTNode>() || e->IsInstance<LENode>() ||
        e->IsInstance<LTNode>() || e->IsInstance<EQNode>() || e->IsInstance<NENode>()) {
      bool is_simple = true;
      std::vector<Var> cand_vars;
      // A comparison is "simple" when, apart from integer vars, it is built
      // only from add/sub/mul/floordiv/floormod and integer constants —
      // i.e. a form the integer-inequality solver can handle.
      PostOrderVisit(e, [&cand_vars, &is_simple, &e](const ObjectRef& obj) {
        if (obj.same_as(e)) {
          return;
        } else if (const VarNode* var = obj.as<VarNode>()) {
          if (var->dtype.is_int() || var->dtype.is_uint()) {
            cand_vars.push_back(GetRef<Var>(var));
          }
        } else {
          is_simple &= obj->IsInstance<AddNode>() || obj->IsInstance<SubNode>() ||
                       obj->IsInstance<MulNode>() || obj->IsInstance<FloorDivNode>() ||
                       obj->IsInstance<FloorModNode>() || obj->IsInstance<IntImmNode>();
        }
      });
      if (is_simple && !cand_vars.empty()) {
        // Record each variable once (by reference identity).
        for (const Var& new_var : cand_vars) {
          if (!std::any_of(vars.begin(), vars.end(),
                           [&new_var](const Var& v) { return v.same_as(new_var); })) {
            vars.push_back(new_var);
          }
        }
        equations.push_back(Downcast<PrimExpr>(e));
      }
    } else if (e->IsInstance<AndNode>()) {
      And op = Downcast<And>(e);
      fvisit(op->a);
      fvisit(op->b);
    } else if (e->IsInstance<CallNode>()) {
      Call op = Downcast<Call>(e);
      if (op->op.same_as(builtin::likely())) {
        fvisit(op->args[0]);
      }
    }
  };
  fvisit(condition);
  // Nothing solvable was found.
  if (equations.empty() || vars.empty()) {
    return Map<Var, Range>();
  }
  // build dom ranges for related vars
  // Prefer the relax map over the hint map when both know the variable.
  Map<Var, Range> ranges;
  for (const Var& v : vars) {
    arith::IntSet dom;
    auto relax_it = relax_map_->find(v.get());
    if (relax_it != relax_map_->end()) {
      dom = relax_it->second;
    } else {
      auto hint_it = hint_map_->find(v.get());
      if (hint_it != hint_map_->end()) {
        dom = hint_it->second;
      }
    }
    if (dom.defined()) {
      // Convert the IntSet [min, max] into a Range of extent max - min + 1.
      ranges.Set(v, Range::FromMinExtent(dom.min(), analyzer.Simplify(dom.max() - dom.min() + 1)));
    }
  }
  // solve constraints
  arith::IntConstraints constraint(vars, ranges, equations);
  auto result = arith::SolveInequalitiesToRange(constraint);
  return result->ranges;
}
399 | |
// Construct a scoped context that tightens variable bounds while inside one
// branch of `condition`; `is_true_branch` selects which branch. The relax
// and hint maps are borrowed (not owned) and mutated by Enter/ExitWithScope.
ConditionalBoundsContext::ConditionalBoundsContext(
    const PrimExpr& condition, std::unordered_map<const VarNode*, arith::IntSet>* relax_map,
    std::unordered_map<const VarNode*, arith::IntSet>* hint_map, bool is_true_branch)
    : condition_(condition),
      relax_map_(relax_map),
      hint_map_(hint_map),
      is_true_branch_(is_true_branch) {}
407 | |
// Enter the branch scope: intersect each solved bound into the relax/hint
// maps, remembering the previous domain in origin_map_ so ExitWithScope can
// restore it.  An IntSet::Nothing() entry in origin_map_ marks a variable
// that had no prior hint entry at all.
void ConditionalBoundsContext::EnterWithScope() {
  for (const auto& p : GetVarBoundsFromCondition()) {
    const auto* var = p.first.get();
    arith::IntSet new_dom = arith::IntSet::FromRange(p.second);
    auto relax_it = relax_map_->find(var);
    if (relax_it != relax_map_->end()) {
      // this is a bound for relaxed var
      origin_map_.emplace(var, relax_it->second);
      relax_it->second = arith::Intersect({relax_it->second, new_dom});
    } else {
      // this is a bound for free var
      auto hint_it = hint_map_->find(var);
      if (hint_it != hint_map_->end()) {
        origin_map_.emplace(var, hint_it->second);
        hint_it->second = arith::Intersect({hint_it->second, new_dom});
      } else {
        // No previous entry: record the Nothing() sentinel so exit erases
        // the hint instead of restoring a value.
        origin_map_.emplace(var, arith::IntSet::Nothing());
        hint_map_->insert(hint_it, {var, new_dom});
      }
    }
  }
}
430 | |
// Leave the branch scope: undo every domain change recorded by
// EnterWithScope.  Entries saved as IntSet::Nothing() had no prior hint and
// are erased rather than restored.
void ConditionalBoundsContext::ExitWithScope() {
  for (const auto& p : origin_map_) {
    const auto* var = p.first;
    auto relax_it = relax_map_->find(var);
    if (relax_it != relax_map_->end()) {
      // recover bound for relaxed var
      relax_it->second = p.second;
    } else {
      // recover bound for free var
      auto hint_it = hint_map_->find(var);
      ICHECK(hint_it != hint_map_->end());
      if (p.second.IsNothing()) {
        hint_map_->erase(hint_it);
      } else {
        hint_it->second = p.second;
      }
    }
  }
}
450 | |
451 | std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const AttrStmtNode* op) { |
452 | ICHECK(op && op->attr_key == tir::attr::async_wait_queue_scope); |
453 | auto inner = op->body.as<AttrStmtNode>(); |
454 | ICHECK(inner && inner->attr_key == tir::attr::async_wait_inflight_count); |
455 | return std::make_pair(op->value, inner->value); |
456 | } |
457 | |
458 | } // namespace tir |
459 | } // namespace tvm |
460 | |