1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file check_contains.cc
 * \brief Implementation of the analysis that tells whether an expression or a statement
 *        contains a node satisfying a given predicate.
24 */
25
26#include "check_contains.h"
27
#include <tvm/tir/expr.h>

#include <functional>
#include <utility>
#include <vector>
31
32namespace tvm {
33namespace tir {
34
35/*!
36 * \brief Toplevel (static) function that tells if an expression contains a subexpression that
37 satisfies a given predicate.
38 * \param expr The expression to check
39 * \param predicate The predicate that must be satisfied
40 * \return Whether `expr` contains a subexpression that satisfies `predicate`
41 */
42bool CheckContains::ExprContains(const PrimExpr& expr,
43 std::function<bool(const PrimExpr&)> predicate) {
44 CheckContains check_contains(predicate);
45 check_contains.VisitExpr(expr);
46 return check_contains.contains_it_;
47}
48
49/*!
50 * \brief Toplevel (static) function that tells if a statement contains a subexpression that
51 satisfies a given predicate.
52 * \param stmt The statement to check
53 * \param predicate The predicate that must be satisfied
54 * \return Whether `stmt` contains a subexpression that satisfies `predicate`
55 */
56bool CheckContains::StmtContains(const Stmt& stmt, std::function<bool(const PrimExpr&)> predicate) {
57 CheckContains check_contains(predicate);
58 check_contains.VisitStmt(stmt);
59 return check_contains.contains_it_;
60}
61
62/*!
63 * \brief Protected constructor of CheckContains.
64 * \param predicate The predicate that must be satisfied
65 */
66CheckContains::CheckContains(std::function<bool(const PrimExpr&)> predicate)
67 : predicate_(predicate) {}
68
69/*!
70 * \brief The method which overrides the generic dispatcher of StmtExprVisitor for expressions.
71 * \param expr The expression to visit
72 */
73void CheckContains::VisitExpr(const PrimExpr& expr) {
74 // If the predicate holds on `expr`, we know `expr` contains something which makes
75 // the predicate hold
76 if (predicate_(expr)) {
77 contains_it_ = true;
78 } else {
79 // Otherwise we continue to look for it recursively by calling the dispatcher
80 StmtExprVisitor::VisitExpr(expr);
81 }
82}
83
84/*!
85 * \brief The method which overrides the generic dispatcher of StmtExprVisitor for statements.
86 * \param stmt The statement to visit
87 */
88void CheckContains::VisitStmt(const Stmt& stmt) {
89 // We keep exploring only if `contains_it_` is false
90 if (!contains_it_) {
91 // and in order to do that we call the general dispatcher
92 StmtExprVisitor::VisitStmt(stmt);
93 }
94 // As otherwise we already have our answer
95}
96
97} // namespace tir
98} // namespace tvm
99