1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file ir_utils.cc
22 * \brief Helper functions to construct and compose IR nodes.
23 */
24#include "ir_utils.h"
25
26#include <tvm/arith/analyzer.h>
27#include <tvm/arith/int_solver.h>
28#include <tvm/tir/stmt_functor.h>
29
30#include <unordered_map>
31#include <unordered_set>
32#include <utility>
33
34namespace tvm {
35namespace tir {
36
37Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
38 // use reverse iteration
39 for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
40 Stmt s = *ri;
41 if (const auto* for_ = s.as<ForNode>()) {
42 auto n = make_object<ForNode>(*for_);
43 ICHECK(is_no_op(n->body));
44 n->body = body;
45 body = Stmt(n);
46 } else if (const auto* let = s.as<LetStmtNode>()) {
47 auto n = make_object<LetStmtNode>(*let);
48 ICHECK(is_no_op(n->body));
49 n->body = body;
50 body = Stmt(n);
51 } else if (const auto* attr = s.as<AttrStmtNode>()) {
52 auto n = make_object<AttrStmtNode>(*attr);
53 ICHECK(is_no_op(n->body));
54 n->body = body;
55 body = Stmt(n);
56 } else if (const auto* ite = s.as<IfThenElseNode>()) {
57 auto n = make_object<IfThenElseNode>(*ite);
58 ICHECK(is_no_op(n->then_case));
59 ICHECK(!n->else_case);
60 n->then_case = body;
61 body = Stmt(n);
62 } else if (const auto* seq = s.as<SeqStmtNode>()) {
63 auto n = make_object<SeqStmtNode>(*seq);
64 ICHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1]));
65 n->seq.Set(n->size() - 1, body);
66 body = Stmt(n);
67 } else if (const auto* assert_ = s.as<AssertStmtNode>()) {
68 auto n = make_object<AssertStmtNode>(*assert_);
69 ICHECK(is_no_op(n->body));
70 n->body = body;
71 body = Stmt(n);
72 } else if (const auto* alloc = s.as<AllocateNode>()) {
73 auto n = make_object<AllocateNode>(*alloc);
74 ICHECK(is_no_op(n->body));
75 n->body = body;
76 body = Stmt(n);
77 } else {
78 LOG(FATAL) << "not supported nest type";
79 }
80 }
81 return body;
82}
83
84Stmt MergeNest(const std::vector<std::vector<Stmt>>& nest, Stmt body) {
85 for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
86 body = MergeNest(*ri, body);
87 }
88 return body;
89}
90
91class IRConvertSSA final : public StmtExprMutator {
92 public:
93 PrimExpr VisitExpr_(const VarNode* op) final {
94 if (scope_.count(op) && !scope_[op].empty()) {
95 return scope_[op].back();
96 } else {
97 return GetRef<PrimExpr>(op);
98 }
99 }
100 PrimExpr VisitExpr_(const LetNode* op) final {
101 const Var& v = op->var;
102 if (defined_.count(v.get())) {
103 PrimExpr value = this->VisitExpr(op->value);
104 ScopedRedefine redefine(this, v);
105 PrimExpr body = this->VisitExpr(op->body);
106 return Let(redefine.new_var, value, body);
107 } else {
108 defined_.insert(v.get());
109 return StmtExprMutator::VisitExpr_(op);
110 }
111 }
112
113 PrimExpr VisitExpr_(const LoadNode* op) final {
114 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
115 }
116
117 Stmt VisitStmt_(const StoreNode* op) final {
118 LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
119 }
120
121 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
122 auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
123 auto output = VisitBufferAccess(std::move(node));
124 return std::move(output);
125 }
126
127 Stmt VisitStmt_(const BufferStoreNode* op) final {
128 auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
129 auto output = VisitBufferAccess(std::move(node));
130 return std::move(output);
131 }
132
133 Stmt VisitStmt_(const DeclBufferNode* op) final {
134 DeclBuffer decl = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
135 Buffer new_buffer = GetRemappedBuffer(decl->buffer);
136 if (!new_buffer.same_as(decl->buffer)) {
137 decl.CopyOnWrite()->buffer = std::move(new_buffer);
138 }
139 return std::move(decl);
140 }
141
142 template <typename Node>
143 Node VisitBufferAccess(Node node) {
144 Buffer new_buf = GetRemappedBuffer(node->buffer);
145 if (!new_buf.same_as(node->buffer)) {
146 auto writer = node.CopyOnWrite();
147 writer->buffer = new_buf;
148 }
149
150 return node;
151 }
152
153 Buffer GetRemappedBuffer(Buffer buf) {
154 // Determine the buffer var that should be in the updated buffer,
155 // given the current scope. If no redefines are present, then the
156 // buffer var is unchanged.
157 Var new_buffer_var = buf->data;
158 auto var_it = scope_.find(buf->data.get());
159 if (var_it != scope_.end() && !var_it->second.empty()) {
160 new_buffer_var = var_it->second.back();
161 }
162
163 // If no mapping is required, return the original buffer.
164 if (new_buffer_var.same_as(buf->data)) {
165 return buf;
166 }
167
168 // If the current scope already has a mapping of this buffer, use
169 // the mapped buffer.
170 auto key = buf.get();
171 std::vector<Buffer>& buffers = buf_remap_[key];
172 if (buffers.size() && buffers.back()->data.same_as(new_buffer_var)) {
173 return buffers.back();
174 }
175
176 // Otherwise, make and return a new buffer object that uses the
177 // new buffer, pushing it onto the scoped stack of existing
178 // buffers. This will be popped when the new_buffer_var
179 // redefinition is popped.
180 Buffer new_buf(new_buffer_var, buf->dtype, buf->shape, buf->strides, buf->elem_offset,
181 buf->name, buf->data_alignment, buf->offset_factor, buf->buffer_type,
182 buf->axis_separators, buf->span);
183 buffers.push_back(new_buf);
184 return new_buf;
185 }
186
187 Stmt VisitStmt_(const LetStmtNode* op) final {
188 const Var& v = op->var;
189 if (defined_.count(v.get())) {
190 PrimExpr value = this->VisitExpr(op->value);
191 ScopedRedefine redefine(this, v);
192 Stmt body = this->VisitStmt(op->body);
193 return LetStmt(redefine.new_var, value, body);
194 } else {
195 defined_.insert(v.get());
196 return StmtExprMutator::VisitStmt_(op);
197 }
198 }
199 Stmt VisitStmt_(const ForNode* op) final {
200 const Var& v = op->loop_var;
201 if (defined_.count(v.get())) {
202 ScopedRedefine redefine(this, v);
203 Stmt stmt = StmtExprMutator::VisitStmt_(op);
204 op = stmt.as<ForNode>();
205 return For(redefine.new_var, op->min, op->extent, op->kind, op->body, op->thread_binding,
206 op->annotations);
207 } else {
208 defined_.insert(v.get());
209 return StmtExprMutator::VisitStmt_(op);
210 }
211 }
212 Stmt VisitStmt_(const AllocateNode* op) final {
213 const Var& v = op->buffer_var;
214 if (defined_.count(v.get())) {
215 ScopedRedefine redefine(this, v);
216 Stmt stmt = StmtExprMutator::VisitStmt_(op);
217 op = stmt.as<AllocateNode>();
218 return Allocate(redefine.new_var, op->dtype, op->extents, op->condition, op->body);
219 } else {
220 defined_.insert(v.get());
221 return StmtExprMutator::VisitStmt_(op);
222 }
223 }
224 Stmt VisitStmt_(const AttrStmtNode* op) final {
225 if (const VarNode* v = op->node.as<VarNode>()) {
226 Stmt stmt = StmtExprMutator::VisitStmt_(op);
227 op = stmt.as<AttrStmtNode>();
228 if (scope_.count(v) && scope_[v].size() != 0) {
229 return AttrStmt(scope_[v].back(), op->attr_key, op->value, op->body);
230 } else {
231 return stmt;
232 }
233 } else {
234 return StmtExprMutator::VisitStmt_(op);
235 }
236 }
237
238 private:
239 struct ScopedRedefine {
240 ScopedRedefine(IRConvertSSA* parent, Var old_var) : parent(parent), old_var(old_var) {
241 if (old_var->type_annotation.defined()) {
242 new_var = Var(old_var->name_hint, old_var->type_annotation);
243 } else {
244 new_var = Var(old_var->name_hint, old_var->dtype);
245 }
246 parent->scope_[old_var.get()].push_back(new_var);
247 }
248
249 ~ScopedRedefine() {
250 parent->scope_[old_var.get()].pop_back();
251 for (auto& kv : parent->buf_remap_) {
252 std::vector<Buffer>& buffers = kv.second;
253 if (buffers.size() && (buffers.back()->data.get() == new_var.get())) {
254 buffers.pop_back();
255 }
256 }
257 }
258
259 IRConvertSSA* parent;
260 Var old_var;
261 Var new_var;
262 };
263
264 std::unordered_map<const VarNode*, std::vector<Var>> scope_;
265 std::unordered_set<const VarNode*> defined_;
266 std::unordered_map<const BufferNode*, std::vector<Buffer>> buf_remap_;
267};
268
269Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); }
270
271String GetPtrStorageScope(Var buffer_var) {
272 const auto* ptr_type = buffer_var->type_annotation.as<PointerTypeNode>();
273 ICHECK(ptr_type) << "The provided variable is not of pointer type";
274 return ptr_type->storage_scope;
275}
276
277Array<PrimExpr> ConvertIndices(const MatchBufferRegion& match_buffer,
278 const Array<PrimExpr>& indices) {
279 const Buffer& target = match_buffer->buffer;
280 const BufferRegion& source = match_buffer->source;
281 ICHECK_EQ(indices.size(), target->shape.size());
282
283 arith::Analyzer analyzer;
284 Array<PrimExpr> result;
285 result.reserve(source->region.size());
286 size_t offset = source->region.size() - indices.size();
287 for (size_t i = 0; i < offset; ++i) {
288 const Range& range = source->region[i];
289 ICHECK(analyzer.CanProve(range->extent == 1));
290 result.push_back(range->min);
291 }
292 for (size_t i = 0; i < indices.size(); ++i) {
293 const Range& range = source->region[i + offset];
294 const PrimExpr& index = indices[i];
295 result.push_back(range->min + index);
296 }
297 return result;
298}
299
300Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region) {
301 const Buffer& target = match_buffer->buffer;
302 const BufferRegion& source = match_buffer->source;
303 ICHECK_EQ(region.size(), target->shape.size());
304
305 arith::Analyzer analyzer;
306 Region result;
307 result.reserve(source->region.size());
308 size_t offset = source->region.size() - region.size();
309 for (size_t i = 0; i < offset; ++i) {
310 const Range& source_range = source->region[i];
311 ICHECK(analyzer.CanProve(source_range->extent == 1));
312 result.push_back(Range::FromMinExtent(source_range->min, 1));
313 }
314 for (size_t i = 0; i < region.size(); ++i) {
315 const Range& source_range = source->region[i + offset];
316 const Range& target_range = region[i];
317 result.push_back(
318 Range::FromMinExtent(source_range->min + target_range->min, target_range->extent));
319 }
320 return result;
321}
322
323Bool IsFromLegacyTESchedule(PrimFunc f) {
324 Optional<Bool> from_legacy_te_schedule = f->GetAttr("from_legacy_te_schedule", Bool(false));
325 return from_legacy_te_schedule.value();
326}
327
// Derive integer ranges for the variables constrained by condition_.
// Extracts "simple" integral comparisons from the (possibly negated)
// condition, seeds the solver with known domains from relax_map_ /
// hint_map_, and solves the inequalities into per-var ranges.
// Returns an empty map if nothing solvable was found.
Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
  // extract equations and related vars from condition expression.
  // currently only extract simple integral equations which could be solvable.
  arith::Analyzer analyzer;
  // For the false branch, reason over the simplified negation.
  PrimExpr condition = is_true_branch_ ? condition_ : analyzer.Simplify(!condition_);
  Array<PrimExpr> equations;
  Array<Var> vars;
  // Recursive visitor: collects comparison sub-expressions, descends
  // through conjunctions, and unwraps likely() annotations.
  std::function<void(const PrimExpr&)> fvisit = [&equations, &vars, &fvisit](const PrimExpr& e) {
    if (e->IsInstance<GENode>() || e->IsInstance<GTNode>() || e->IsInstance<LENode>() ||
        e->IsInstance<LTNode>() || e->IsInstance<EQNode>() || e->IsInstance<NENode>()) {
      bool is_simple = true;
      std::vector<Var> cand_vars;
      // A comparison is "simple" when every sub-node (besides the
      // comparison itself) is an int var or an affine-style operation.
      PostOrderVisit(e, [&cand_vars, &is_simple, &e](const ObjectRef& obj) {
        if (obj.same_as(e)) {
          return;
        } else if (const VarNode* var = obj.as<VarNode>()) {
          // Only integral vars are candidates for the solver.
          if (var->dtype.is_int() || var->dtype.is_uint()) {
            cand_vars.push_back(GetRef<Var>(var));
          }
        } else {
          is_simple &= obj->IsInstance<AddNode>() || obj->IsInstance<SubNode>() ||
                       obj->IsInstance<MulNode>() || obj->IsInstance<FloorDivNode>() ||
                       obj->IsInstance<FloorModNode>() || obj->IsInstance<IntImmNode>();
        }
      });
      if (is_simple && !cand_vars.empty()) {
        // Deduplicate vars across equations (same_as identity).
        for (const Var& new_var : cand_vars) {
          if (!std::any_of(vars.begin(), vars.end(),
                           [&new_var](const Var& v) { return v.same_as(new_var); })) {
            vars.push_back(new_var);
          }
        }
        equations.push_back(Downcast<PrimExpr>(e));
      }
    } else if (e->IsInstance<AndNode>()) {
      // Conjunctions contribute both operands as constraints.
      And op = Downcast<And>(e);
      fvisit(op->a);
      fvisit(op->b);
    } else if (e->IsInstance<CallNode>()) {
      // Unwrap likely(cond) markers.
      Call op = Downcast<Call>(e);
      if (op->op.same_as(builtin::likely())) {
        fvisit(op->args[0]);
      }
    }
  };
  fvisit(condition);
  if (equations.empty() || vars.empty()) {
    return Map<Var, Range>();
  }
  // build dom ranges for related vars
  // relax_map_ takes precedence over hint_map_ as the domain source.
  Map<Var, Range> ranges;
  for (const Var& v : vars) {
    arith::IntSet dom;
    auto relax_it = relax_map_->find(v.get());
    if (relax_it != relax_map_->end()) {
      dom = relax_it->second;
    } else {
      auto hint_it = hint_map_->find(v.get());
      if (hint_it != hint_map_->end()) {
        dom = hint_it->second;
      }
    }
    if (dom.defined()) {
      // Convert [min, max] to Range as [min, max - min + 1).
      ranges.Set(v, Range::FromMinExtent(dom.min(), analyzer.Simplify(dom.max() - dom.min() + 1)));
    }
  }
  // solve constraints
  arith::IntConstraints constraint(vars, ranges, equations);
  auto result = arith::SolveInequalitiesToRange(constraint);
  return result->ranges;
}
399
// Construct a bounds context for one branch of `condition`.
// relax_map / hint_map are borrowed (not owned) and are updated in
// place by EnterWithScope / restored by ExitWithScope; is_true_branch
// selects whether the condition or its negation constrains the scope.
ConditionalBoundsContext::ConditionalBoundsContext(
    const PrimExpr& condition, std::unordered_map<const VarNode*, arith::IntSet>* relax_map,
    std::unordered_map<const VarNode*, arith::IntSet>* hint_map, bool is_true_branch)
    : condition_(condition),
      relax_map_(relax_map),
      hint_map_(hint_map),
      is_true_branch_(is_true_branch) {}
407
// Tighten the var domains in relax_map_/hint_map_ using the bounds
// implied by the branch condition, remembering each original domain in
// origin_map_ so ExitWithScope can restore it.
void ConditionalBoundsContext::EnterWithScope() {
  for (const auto& p : GetVarBoundsFromCondition()) {
    const auto* var = p.first.get();
    arith::IntSet new_dom = arith::IntSet::FromRange(p.second);
    auto relax_it = relax_map_->find(var);
    if (relax_it != relax_map_->end()) {
      // this is a bound for relaxed var
      origin_map_.emplace(var, relax_it->second);
      relax_it->second = arith::Intersect({relax_it->second, new_dom});
    } else {
      // this is a bound for free var
      auto hint_it = hint_map_->find(var);
      if (hint_it != hint_map_->end()) {
        origin_map_.emplace(var, hint_it->second);
        hint_it->second = arith::Intersect({hint_it->second, new_dom});
      } else {
        // No prior entry: record IntSet::Nothing() as a sentinel so
        // ExitWithScope knows to erase (not restore) this hint.
        origin_map_.emplace(var, arith::IntSet::Nothing());
        // hint_it == end() here; passing it as an insertion hint is valid.
        hint_map_->insert(hint_it, {var, new_dom});
      }
    }
  }
}
430
// Undo EnterWithScope: restore every domain saved in origin_map_,
// erasing hint entries that did not exist before (Nothing sentinel).
void ConditionalBoundsContext::ExitWithScope() {
  for (const auto& p : origin_map_) {
    const auto* var = p.first;
    auto relax_it = relax_map_->find(var);
    if (relax_it != relax_map_->end()) {
      // recover bound for relaxed var
      relax_it->second = p.second;
    } else {
      // recover bound for free var
      auto hint_it = hint_map_->find(var);
      ICHECK(hint_it != hint_map_->end());
      if (p.second.IsNothing()) {
        // Entry was created by EnterWithScope; remove it entirely.
        hint_map_->erase(hint_it);
      } else {
        hint_it->second = p.second;
      }
    }
  }
}
450
451std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const AttrStmtNode* op) {
452 ICHECK(op && op->attr_key == tir::attr::async_wait_queue_scope);
453 auto inner = op->body.as<AttrStmtNode>();
454 ICHECK(inner && inner->attr_key == tir::attr::async_wait_inflight_count);
455 return std::make_pair(op->value, inner->value);
456}
457
458} // namespace tir
459} // namespace tvm
460