1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/transforms/simplify_expr.h
22 * \brief Utility data structures for simplifying Relay expressions.
23 */
24#ifndef TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
25#define TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
26
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr.h>

#include <memory>
#include <utility>
#include <vector>
32
33namespace tvm {
34namespace relay {
35
36/*! \brief A wrapper class defining a rewrite matching a specific pattern. */
37class DFPatternRewrite {
38 public:
39 /*! \brief Returns the rewritten expression. */
40 virtual Expr Callback(const Expr& pre, const Expr& post,
41 const Map<DFPattern, Array<Expr>>& node_map) const = 0;
42
43 virtual ~DFPatternRewrite() = default;
44
45 /*! \brief Returns the pattern to be used for matching and rewriting. */
46 inline DFPattern Pattern() const { return pattern_; }
47
48 inline bool RequireType() const { return require_type_; }
49
50 inline DFPatternCallback MakeCallback() const {
51 auto func = [this](TVMArgs args, TVMRetValue* rv) {
52 Expr pre = args[0];
53 Expr post = args[1];
54 Map<DFPattern, Array<Expr>> node_map = args[2];
55 *rv = this->Callback(pre, post, node_map);
56 };
57 return DFPatternCallback(pattern_, PackedFunc(func), require_type_, rewrite_once_);
58 }
59
60 protected:
61 /*! \brief The pattern for matching and rewriting. */
62 DFPattern pattern_;
63 /*! \brief Whether or not the rewrite requires types to be inferred. */
64 bool require_type_ = true;
65 /*! \brief Whether or not run the callback only once */
66 bool rewrite_once_ = false;
67};
68
69/*! \brief Helper class for composing rewrites and getting callbacks. */
70class DFPatternRewriteComposer {
71 public:
72 template <typename T, typename... Args>
73 inline void AddRewrite(Args... args) {
74 rewrites_.push_back(std::make_shared<T, Args&...>(args...));
75 }
76
77 inline Array<DFPatternCallback> MakeCallbacks() const {
78 Array<DFPatternCallback> callbacks;
79 for (const auto& rewrite : rewrites_) {
80 callbacks.push_back(rewrite->MakeCallback());
81 }
82 return callbacks;
83 }
84
85 private:
86 /*! \brief the rewrites to be composed. */
87 std::vector<std::shared_ptr<DFPatternRewrite>> rewrites_;
88};
89
90} // namespace relay
91} // namespace tvm
92
93#endif // TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
94