1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/arith/ir_visitor_with_analyzer.cc
22 */
23#include "ir_visitor_with_analyzer.h"
24
25#include <tvm/tir/analysis.h>
26#include <tvm/tir/builtin.h>
27#include <tvm/tir/op.h>
28
29namespace tvm {
30namespace arith {
31
32using namespace tir;
33
34void IRVisitorWithAnalyzer::VisitStmt_(const ForNode* op) {
35 analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
36 StmtExprVisitor::VisitStmt_(op);
37}
38
39void IRVisitorWithAnalyzer::VisitStmt_(const BlockNode* op) {
40 for (const auto& iter_var : op->iter_vars) {
41 analyzer_.Bind(iter_var->var, iter_var->dom);
42 }
43 StmtExprVisitor::VisitStmt_(op);
44}
45
46void IRVisitorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {
47 this->VisitExpr(op->value);
48 analyzer_.Bind(op->var, op->value);
49 this->VisitStmt(op->body);
50}
51
52void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
53 this->VisitExpr(op->condition);
54
55 PrimExpr real_condition = ExtractRealCondition(op->condition);
56
57 {
58 With<ConstraintContext> constraint(&analyzer_, real_condition);
59 this->VisitStmt(op->then_case);
60 }
61 if (op->else_case) {
62 With<ConstraintContext> constraint(&analyzer_, analyzer_.rewrite_simplify(Not(real_condition)));
63 this->VisitStmt(op->else_case.value());
64 }
65}
66
67void IRVisitorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) {
68 if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) {
69 IterVar iv = Downcast<IterVar>(op->node);
70 ICHECK_NE(iv->thread_tag.length(), 0U);
71 analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value));
72 }
73 StmtExprVisitor::VisitStmt_(op);
74}
75
76void IRVisitorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) {
77 this->VisitExpr(op->condition);
78 this->VisitExpr(op->message);
79 With<ConstraintContext> constraint(&analyzer_, op->condition);
80 this->VisitStmt(op->body);
81}
82
83void IRVisitorWithAnalyzer::VisitExpr_(const CallNode* op) {
84 // add condition context to if_then_else
85 static auto op_if_then_else = Op::Get("tir.if_then_else");
86 if (op->op.same_as(op_if_then_else)) {
87 PrimExpr cond = op->args[0];
88 this->VisitExpr(op->args[0]);
89 {
90 With<ConstraintContext> constraint(&analyzer_, cond);
91 this->VisitExpr(op->args[1]);
92 }
93 {
94 With<ConstraintContext> constraint(&analyzer_, analyzer_.rewrite_simplify(Not(cond)));
95 this->VisitExpr(op->args[2]);
96 }
97 } else {
98 StmtExprVisitor::VisitExpr_(op);
99 }
100}
101
102void IRVisitorWithAnalyzer::VisitExpr_(const LetNode* op) {
103 this->VisitExpr(op->value);
104 analyzer_.Bind(op->var, op->value);
105 this->VisitExpr(op->body);
106}
107
108void IRVisitorWithAnalyzer::VisitExpr_(const ReduceNode* op) {
109 for (const IterVar& iv : op->axis) {
110 analyzer_.Bind(iv->var, iv->dom);
111 }
112 StmtExprVisitor::VisitExpr_(op);
113}
114
115PrimExpr IRVisitorWithAnalyzer::ExtractRealCondition(PrimExpr condition) const {
116 if (auto call = condition.as<CallNode>()) {
117 if (call->op.same_as(builtin::likely())) {
118 return call->args[0];
119 }
120 }
121
122 return condition;
123}
124
125} // namespace arith
126} // namespace tvm
127