1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/arith/ir_visitor_with_analyzer.cc |
22 | */ |
23 | #include "ir_visitor_with_analyzer.h" |
24 | |
25 | #include <tvm/tir/analysis.h> |
26 | #include <tvm/tir/builtin.h> |
27 | #include <tvm/tir/op.h> |
28 | |
29 | namespace tvm { |
30 | namespace arith { |
31 | |
32 | using namespace tir; |
33 | |
34 | void IRVisitorWithAnalyzer::VisitStmt_(const ForNode* op) { |
35 | analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); |
36 | StmtExprVisitor::VisitStmt_(op); |
37 | } |
38 | |
39 | void IRVisitorWithAnalyzer::VisitStmt_(const BlockNode* op) { |
40 | for (const auto& iter_var : op->iter_vars) { |
41 | analyzer_.Bind(iter_var->var, iter_var->dom); |
42 | } |
43 | StmtExprVisitor::VisitStmt_(op); |
44 | } |
45 | |
46 | void IRVisitorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { |
47 | this->VisitExpr(op->value); |
48 | analyzer_.Bind(op->var, op->value); |
49 | this->VisitStmt(op->body); |
50 | } |
51 | |
52 | void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { |
53 | this->VisitExpr(op->condition); |
54 | |
55 | PrimExpr real_condition = ExtractRealCondition(op->condition); |
56 | |
57 | { |
58 | With<ConstraintContext> constraint(&analyzer_, real_condition); |
59 | this->VisitStmt(op->then_case); |
60 | } |
61 | if (op->else_case) { |
62 | With<ConstraintContext> constraint(&analyzer_, analyzer_.rewrite_simplify(Not(real_condition))); |
63 | this->VisitStmt(op->else_case.value()); |
64 | } |
65 | } |
66 | |
67 | void IRVisitorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { |
68 | if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) { |
69 | IterVar iv = Downcast<IterVar>(op->node); |
70 | ICHECK_NE(iv->thread_tag.length(), 0U); |
71 | analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value)); |
72 | } |
73 | StmtExprVisitor::VisitStmt_(op); |
74 | } |
75 | |
76 | void IRVisitorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) { |
77 | this->VisitExpr(op->condition); |
78 | this->VisitExpr(op->message); |
79 | With<ConstraintContext> constraint(&analyzer_, op->condition); |
80 | this->VisitStmt(op->body); |
81 | } |
82 | |
83 | void IRVisitorWithAnalyzer::VisitExpr_(const CallNode* op) { |
84 | // add condition context to if_then_else |
85 | static auto op_if_then_else = Op::Get("tir.if_then_else" ); |
86 | if (op->op.same_as(op_if_then_else)) { |
87 | PrimExpr cond = op->args[0]; |
88 | this->VisitExpr(op->args[0]); |
89 | { |
90 | With<ConstraintContext> constraint(&analyzer_, cond); |
91 | this->VisitExpr(op->args[1]); |
92 | } |
93 | { |
94 | With<ConstraintContext> constraint(&analyzer_, analyzer_.rewrite_simplify(Not(cond))); |
95 | this->VisitExpr(op->args[2]); |
96 | } |
97 | } else { |
98 | StmtExprVisitor::VisitExpr_(op); |
99 | } |
100 | } |
101 | |
102 | void IRVisitorWithAnalyzer::VisitExpr_(const LetNode* op) { |
103 | this->VisitExpr(op->value); |
104 | analyzer_.Bind(op->var, op->value); |
105 | this->VisitExpr(op->body); |
106 | } |
107 | |
108 | void IRVisitorWithAnalyzer::VisitExpr_(const ReduceNode* op) { |
109 | for (const IterVar& iv : op->axis) { |
110 | analyzer_.Bind(iv->var, iv->dom); |
111 | } |
112 | StmtExprVisitor::VisitExpr_(op); |
113 | } |
114 | |
115 | PrimExpr IRVisitorWithAnalyzer::(PrimExpr condition) const { |
116 | if (auto call = condition.as<CallNode>()) { |
117 | if (call->op.same_as(builtin::likely())) { |
118 | return call->args[0]; |
119 | } |
120 | } |
121 | |
122 | return condition; |
123 | } |
124 | |
125 | } // namespace arith |
126 | } // namespace tvm |
127 | |