1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/transforms/simplify_expr.h |
22 | * \brief Utility data structures for simplifying Relay expressions. |
23 | */ |
24 | #ifndef TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_ |
25 | #define TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_ |
26 | |
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr.h>

#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
32 | |
33 | namespace tvm { |
34 | namespace relay { |
35 | |
36 | /*! \brief A wrapper class defining a rewrite matching a specific pattern. */ |
37 | class DFPatternRewrite { |
38 | public: |
39 | /*! \brief Returns the rewritten expression. */ |
40 | virtual Expr Callback(const Expr& pre, const Expr& post, |
41 | const Map<DFPattern, Array<Expr>>& node_map) const = 0; |
42 | |
43 | virtual ~DFPatternRewrite() = default; |
44 | |
45 | /*! \brief Returns the pattern to be used for matching and rewriting. */ |
46 | inline DFPattern Pattern() const { return pattern_; } |
47 | |
48 | inline bool RequireType() const { return require_type_; } |
49 | |
50 | inline DFPatternCallback MakeCallback() const { |
51 | auto func = [this](TVMArgs args, TVMRetValue* rv) { |
52 | Expr pre = args[0]; |
53 | Expr post = args[1]; |
54 | Map<DFPattern, Array<Expr>> node_map = args[2]; |
55 | *rv = this->Callback(pre, post, node_map); |
56 | }; |
57 | return DFPatternCallback(pattern_, PackedFunc(func), require_type_, rewrite_once_); |
58 | } |
59 | |
60 | protected: |
61 | /*! \brief The pattern for matching and rewriting. */ |
62 | DFPattern pattern_; |
63 | /*! \brief Whether or not the rewrite requires types to be inferred. */ |
64 | bool require_type_ = true; |
65 | /*! \brief Whether or not run the callback only once */ |
66 | bool rewrite_once_ = false; |
67 | }; |
68 | |
69 | /*! \brief Helper class for composing rewrites and getting callbacks. */ |
70 | class DFPatternRewriteComposer { |
71 | public: |
72 | template <typename T, typename... Args> |
73 | inline void AddRewrite(Args... args) { |
74 | rewrites_.push_back(std::make_shared<T, Args&...>(args...)); |
75 | } |
76 | |
77 | inline Array<DFPatternCallback> MakeCallbacks() const { |
78 | Array<DFPatternCallback> callbacks; |
79 | for (const auto& rewrite : rewrites_) { |
80 | callbacks.push_back(rewrite->MakeCallback()); |
81 | } |
82 | return callbacks; |
83 | } |
84 | |
85 | private: |
86 | /*! \brief the rewrites to be composed. */ |
87 | std::vector<std::shared_ptr<DFPatternRewrite>> rewrites_; |
88 | }; |
89 | |
90 | } // namespace relay |
91 | } // namespace tvm |
92 | |
93 | #endif // TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_ |
94 | |