1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file check_contains.cc |
22 | * \brief Implementation of the analysis that tells if an expression contains |
23 | a node that satisfies a given predicate. |
24 | */ |
25 | |
#include "check_contains.h"

#include <tvm/tir/expr.h>

#include <utility>
#include <vector>
31 | |
32 | namespace tvm { |
33 | namespace tir { |
34 | |
35 | /*! |
36 | * \brief Toplevel (static) function that tells if an expression contains a subexpression that |
37 | satisfies a given predicate. |
38 | * \param expr The expression to check |
39 | * \param predicate The predicate that must be satisfied |
40 | * \return Whether `expr` contains a subexpression that satisfies `predicate` |
41 | */ |
42 | bool CheckContains::ExprContains(const PrimExpr& expr, |
43 | std::function<bool(const PrimExpr&)> predicate) { |
44 | CheckContains check_contains(predicate); |
45 | check_contains.VisitExpr(expr); |
46 | return check_contains.contains_it_; |
47 | } |
48 | |
49 | /*! |
50 | * \brief Toplevel (static) function that tells if a statement contains a subexpression that |
51 | satisfies a given predicate. |
52 | * \param stmt The statement to check |
53 | * \param predicate The predicate that must be satisfied |
54 | * \return Whether `stmt` contains a subexpression that satisfies `predicate` |
55 | */ |
56 | bool CheckContains::StmtContains(const Stmt& stmt, std::function<bool(const PrimExpr&)> predicate) { |
57 | CheckContains check_contains(predicate); |
58 | check_contains.VisitStmt(stmt); |
59 | return check_contains.contains_it_; |
60 | } |
61 | |
62 | /*! |
63 | * \brief Protected constructor of CheckContains. |
64 | * \param predicate The predicate that must be satisfied |
65 | */ |
66 | CheckContains::CheckContains(std::function<bool(const PrimExpr&)> predicate) |
67 | : predicate_(predicate) {} |
68 | |
69 | /*! |
70 | * \brief The method which overrides the generic dispatcher of StmtExprVisitor for expressions. |
71 | * \param expr The expression to visit |
72 | */ |
73 | void CheckContains::VisitExpr(const PrimExpr& expr) { |
74 | // If the predicate holds on `expr`, we know `expr` contains something which makes |
75 | // the predicate hold |
76 | if (predicate_(expr)) { |
77 | contains_it_ = true; |
78 | } else { |
79 | // Otherwise we continue to look for it recursively by calling the dispatcher |
80 | StmtExprVisitor::VisitExpr(expr); |
81 | } |
82 | } |
83 | |
84 | /*! |
85 | * \brief The method which overrides the generic dispatcher of StmtExprVisitor for statements. |
86 | * \param stmt The statement to visit |
87 | */ |
88 | void CheckContains::VisitStmt(const Stmt& stmt) { |
89 | // We keep exploring only if `contains_it_` is false |
90 | if (!contains_it_) { |
91 | // and in order to do that we call the general dispatcher |
92 | StmtExprVisitor::VisitStmt(stmt); |
93 | } |
94 | // As otherwise we already have our answer |
95 | } |
96 | |
97 | } // namespace tir |
98 | } // namespace tvm |
99 | |