1 | /* |
---|---|
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file var_touch.cc |
22 | * \brief Implementation of UsesVar: checks whether a Stmt or PrimExpr references any variable selected by a caller-supplied predicate. |
23 | */ |
24 | #include <tvm/tir/analysis.h> |
25 | #include <tvm/tir/stmt_functor.h> |
26 | |
27 | namespace tvm { |
28 | namespace tir { |
29 | |
30 | class VarTouchVisitor : public StmtExprVisitor { |
31 | public: |
32 | explicit VarTouchVisitor(std::function<bool(const VarNode*)> var_set) |
33 | : var_set_(std::move(var_set)) {} |
34 | |
35 | void VisitStmt(const Stmt& stmt) final { |
36 | if (use_var_) return; |
37 | StmtExprVisitor::VisitStmt(stmt); |
38 | } |
39 | |
40 | void VisitExpr(const PrimExpr& e) final { |
41 | if (use_var_) return; |
42 | StmtExprVisitor::VisitExpr(e); |
43 | } |
44 | |
45 | void VisitExpr_(const VarNode* op) final { Handle(op); } |
46 | |
47 | void VisitExpr_(const LoadNode* op) final { |
48 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; |
49 | } |
50 | |
51 | void VisitStmt_(const StoreNode* op) final { |
52 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead."; |
53 | } |
54 | |
55 | void VisitStmt_(const BufferStoreNode* op) final { |
56 | Handle(op->buffer->data.get()); |
57 | StmtVisitor::VisitStmt_(op); |
58 | } |
59 | |
60 | void VisitExpr_(const BufferLoadNode* op) final { |
61 | Handle(op->buffer->data.get()); |
62 | ExprVisitor::VisitExpr_(op); |
63 | } |
64 | |
65 | void Handle(const VarNode* var) { |
66 | if (var_set_(var)) use_var_ = true; |
67 | } |
68 | |
69 | bool use_var_{false}; |
70 | |
71 | private: |
72 | std::function<bool(const VarNode*)> var_set_; |
73 | }; |
74 | |
75 | bool UsesVar(const Stmt& stmt, std::function<bool(const VarNode*)> var_set) { |
76 | VarTouchVisitor visitor(std::move(var_set)); |
77 | visitor(stmt); |
78 | return visitor.use_var_; |
79 | } |
80 | |
81 | bool UsesVar(const PrimExpr& expr, std::function<bool(const VarNode*)> var_set) { |
82 | VarTouchVisitor visitor(std::move(var_set)); |
83 | visitor(expr); |
84 | return visitor.use_var_; |
85 | } |
86 | |
87 | } // namespace tir |
88 | } // namespace tvm |
89 |