1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/relay/dataflow_matcher.h
22 * \brief A pattern matcher for matching dataflow properties.
23 */
24#ifndef TVM_RELAY_DATAFLOW_MATCHER_H_
25#define TVM_RELAY_DATAFLOW_MATCHER_H_
26
27#include <tvm/relay/dataflow_pattern.h>
28#include <tvm/relay/dataflow_pattern_functor.h>
29
30#include <string>
31#include <unordered_map>
32#include <utility>
33
34namespace tvm {
35namespace relay {
36
37class DFPatternCallback;
38/*!
39 * \brief Base type of all dataflow pattern callbacks.
40 * \sa DFPatternCallback
41 */
42class DFPatternCallbackNode : public Object {
43 public:
44 /*! \brief Pattern this callback matches */
45 DFPattern pattern;
46 /*! \brief Function to call when finding a matched expression */
47 PackedFunc function;
48 /*! \brief Require InferType to be run before the callback */
49 bool require_type;
50 /*! \brief Run the callback only once */
51 bool rewrite_once;
52
53 void VisitAttrs(tvm::AttrVisitor* v) {
54 v->Visit("pattern", &pattern);
55 v->Visit("require_type", &require_type);
56 v->Visit("rewrite_once", &rewrite_once);
57 }
58
59 static constexpr const char* _type_key = "DFPatternCallbackNode";
60 TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object);
61};
62
63/*!
64 * \brief Managed reference to dataflow pattern callbacks.
65 * \sa DFPatternCallbackNode
66 */
67class DFPatternCallback : public ObjectRef {
68 public:
69 TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback, bool require_type,
70 bool rewrite_once = false);
71 TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode);
72};
73
74/*!
75 * \brief Determine if a pattern matches an expression
76 *
77 * \param pattern The pattern to match
78 * \param expr The expression to match
79 *
80 * \return Return true if the pattern and the expression match, return false otherwise.
81 */
82bool MatchPattern(DFPattern pattern, Expr expr);
83
84/*!
85 * \brief Rewrite an expression based on some number of DFPatternCallbacks
86 *
87 * \param callbacks An array of DFPatternCallback Nodes
88 * \param expr The expression to rewrite
89 * \param mod The module that associates with the expr
90 *
91 * \return Return An Expr with every match of the pattern inside the callbacks rewritten by the
92 * functions inside the callbacks
93 */
94Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr, IRModule mod = IRModule());
95
96/*!
97 * \brief Partition all matches of a DFPattern inside an Expr into separate Function calls
98 *
99 * \param pattern The pattern to match
100 * \param expr The expression to patition
101 * \param attrs A set of parameter names and values to apply to the partitioned function
102 * \param check A callback function for checking more complicated properties of the matched
103 * expressions, returns true if the match is accepted and false otherwise
104 *
105 * \return Return the paritioned Expr.
106 */
107Expr PartitionPattern(DFPattern pattern, Expr expr, Map<String, ObjectRef> attrs, PackedFunc check);
108
109/*!
110 * \brief Infer the type of an expression.
111 *
112 * \param expr The expression to rewrite
113 *
114 * \return Return An Expr with unambiguous type information filled in, as well as it's
115 * checked type field populated with the result type.
116 *
117 */
118Expr InferType(const Expr& expr);
119
120} // namespace relay
121} // namespace tvm
122
123#endif // TVM_RELAY_DATAFLOW_MATCHER_H_
124