1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /*! |
20 | * \file tvm/node/functor.h |
21 | * \brief Defines the Functor data structures. |
22 | */ |
23 | #ifndef TVM_NODE_FUNCTOR_H_ |
24 | #define TVM_NODE_FUNCTOR_H_ |
25 | |
26 | #include <dmlc/logging.h> |
27 | #include <tvm/runtime/object.h> |
28 | |
29 | #include <type_traits> |
30 | #include <utility> |
31 | #include <vector> |
32 | |
33 | namespace tvm { |
34 | |
35 | using runtime::ObjectRef; |
36 | |
/*!
 * \brief A dynamically dispatched functor on the type of the first argument.
 *
 * This class is useful for constructing polymorphic dispatch
 * based on the AST/IR node's type.
 *
 * \code
 *   NodeFunctor<std::string (const ObjectRef& n, std::string prefix)> tostr;
 *   tostr.set_dispatch<Add>([](const ObjectRef& op, std::string prefix) {
 *     return prefix + "Add";
 *   });
 *   tostr.set_dispatch<IntImm>([](const ObjectRef& op, std::string prefix) {
 *     return prefix + "IntImm";
 *   });
 *
 *   Expr x = make_const(1);
 *   Expr y = x + x;
 *   // dispatch to IntImm, outputs "MyIntImm"
 *   LOG(INFO) << tostr(x, "My");
 *   // dispatch to Add, outputs "MyAdd"
 *   LOG(INFO) << tostr(y, "My");
 * \endcode
 *
 * \tparam FType function signature.
 *  The primary template is only declared here; the type is only defined for
 *  FType that is a function signature of the form R(const ObjectRef& n, Args...).
 */
template <typename FType>
class NodeFunctor;
65 | |
66 | template <typename R, typename... Args> |
67 | class NodeFunctor<R(const ObjectRef& n, Args...)> { |
68 | private: |
69 | /*! \brief internal function pointer type */ |
70 | typedef R (*FPointer)(const ObjectRef& n, Args...); |
71 | /*! \brief refer to itself. */ |
72 | using TSelf = NodeFunctor<R(const ObjectRef& n, Args...)>; |
73 | /*! \brief internal function table */ |
74 | std::vector<FPointer> func_; |
75 | |
76 | public: |
77 | /*! \brief the result type of this functor */ |
78 | using result_type = R; |
79 | /*! |
80 | * \brief Whether the functor can dispatch the corresponding Node |
81 | * \param n The node to be dispatched |
82 | * \return Whether dispatching function is registered for n's type. |
83 | */ |
84 | bool can_dispatch(const ObjectRef& n) const { |
85 | uint32_t type_index = n->type_index(); |
86 | return type_index < func_.size() && func_[type_index] != nullptr; |
87 | } |
88 | /*! |
89 | * \brief invoke the functor, dispatch on type of n |
90 | * \param n The Node argument |
91 | * \param args The additional arguments |
92 | * \return The result. |
93 | */ |
94 | R operator()(const ObjectRef& n, Args... args) const { |
95 | ICHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on type " |
96 | << n->GetTypeKey(); |
97 | return (*func_[n->type_index()])(n, std::forward<Args>(args)...); |
98 | } |
99 | /*! |
100 | * \brief set the dispatcher for type TNode |
101 | * \param f The function to be set. |
102 | * \tparam TNode the type of Node to be dispatched. |
103 | * \return reference to self. |
104 | */ |
105 | template <typename TNode> |
106 | TSelf& set_dispatch(FPointer f) { // NOLINT(*) |
107 | uint32_t tindex = TNode::RuntimeTypeIndex(); |
108 | if (func_.size() <= tindex) { |
109 | func_.resize(tindex + 1, nullptr); |
110 | } |
111 | ICHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key << " is already set" ; |
112 | func_[tindex] = f; |
113 | return *this; |
114 | } |
115 | /*! |
116 | * \brief unset the dispatcher for type TNode |
117 | * |
118 | * \tparam TNode the type of Node to be dispatched. |
119 | * \return reference to self. |
120 | */ |
121 | template <typename TNode> |
122 | TSelf& clear_dispatch() { // NOLINT(*) |
123 | uint32_t tindex = TNode::RuntimeTypeIndex(); |
124 | ICHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range" ; |
125 | func_[tindex] = nullptr; |
126 | return *this; |
127 | } |
128 | }; |
129 | |
/*!
 * \brief Start the declaration of a static reference variable named after ClsName.
 *
 * Expands to `static TVM_ATTRIBUTE_UNUSED auto& __make_functor_<ClsName>`;
 * the initializer is supplied by the macro that uses it
 * (see TVM_STATIC_IR_FUNCTOR).
 *
 * \param ClsName The name of the class, pasted into the variable name.
 */
#define TVM_REG_FUNC_VAR_DEF(ClsName) static TVM_ATTRIBUTE_UNUSED auto& __make_functor##_##ClsName
131 | |
/*!
 * \brief Useful macro to set NodeFunctor dispatch in a global static field.
 *
 * The macro declares a static reference (its name is built from ClsName and
 * __COUNTER__) initialized with ClsName::FField(). Calls chained after the
 * macro, such as `.set_dispatch<T>(...)`, are therefore applied to the object
 * returned by ClsName::FField().
 *
 * \code
 *  // Use NodeFunctor to implement ReprPrinter similar to Visitor Pattern.
 *  // vtable allows easy patch of new Node types, without changing
 *  // interface of ReprPrinter.
 *
 *  class ReprPrinter {
 *   public:
 *    std::ostream& stream;
 *    // the dispatch function.
 *    void print(Expr e) {
 *      const static FType& f = vtable();
 *      f(e, this);
 *    }
 *
 *    using FType = NodeFunctor<void (const ObjectRef&, ReprPrinter* )>;
 *    // function to return global function table
 *    static FType& vtable();
 *  };
 *
 *  // in cpp/cc file
 *  ReprPrinter::FType& ReprPrinter::vtable() { // NOLINT(*)
 *    static FType inst; return inst;
 *  }
 *
 *  TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 *  .set_dispatch<Add>([](const ObjectRef& ref, ReprPrinter* p) {
 *    auto* n = static_cast<const Add*>(ref.get());
 *    p->print(n->a);
 *    p->stream << '+';
 *    p->print(n->b);
 *  });
 *
 *
 * \endcode
 *
 * \param ClsName The name of the class
 * \param FField The static function that returns a singleton of NodeFunctor.
 */
#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \
  TVM_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__) = ClsName::FField()
175 | } // namespace tvm |
176 | #endif // TVM_NODE_FUNCTOR_H_ |
177 | |