1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
20 * \file tvm/node/functor.h
21 * \brief Defines the Functor data structures.
22 */
23#ifndef TVM_NODE_FUNCTOR_H_
24#define TVM_NODE_FUNCTOR_H_
25
26#include <dmlc/logging.h>
27#include <tvm/runtime/object.h>
28
29#include <type_traits>
30#include <utility>
31#include <vector>
32
33namespace tvm {
34
35using runtime::ObjectRef;
36
37/*!
38 * \brief A dynamically dispatched functor on the type of the first argument.
39 *
40 * This is a class that is useful to construct polymorphic dispatching
41 * base on the AST/IR node's type.
42 *
43 * \code
44 * NodeFunctor<std::string (const ObjectRef& n, std::string prefix)> tostr;
45 * tostr.set_dispatch<Add>([](const ObjectRef& op, std::string prefix) {
46 * return prefix + "Add";
47 * });
48 * tostr.set_dispatch<IntImm>([](const ObjectRef& op, std::string prefix) {
49 * return prefix + "IntImm"
50 * });
51 *
52 * Expr x = make_const(1);
53 * Expr y = x + x;
54 * // dispatch to IntImm, outputs "MyIntImm"
55 * LOG(INFO) << tostr(x, "My");
56 * // dispatch to IntImm, outputs "MyAdd"
57 * LOG(INFO) << tostr(y, "My");
58 * \endcode
59 *
60 * \tparam FType function signiture
61 * This type if only defined for FType with function signature
62 */
63template <typename FType>
64class NodeFunctor;
65
66template <typename R, typename... Args>
67class NodeFunctor<R(const ObjectRef& n, Args...)> {
68 private:
69 /*! \brief internal function pointer type */
70 typedef R (*FPointer)(const ObjectRef& n, Args...);
71 /*! \brief refer to itself. */
72 using TSelf = NodeFunctor<R(const ObjectRef& n, Args...)>;
73 /*! \brief internal function table */
74 std::vector<FPointer> func_;
75
76 public:
77 /*! \brief the result type of this functor */
78 using result_type = R;
79 /*!
80 * \brief Whether the functor can dispatch the corresponding Node
81 * \param n The node to be dispatched
82 * \return Whether dispatching function is registered for n's type.
83 */
84 bool can_dispatch(const ObjectRef& n) const {
85 uint32_t type_index = n->type_index();
86 return type_index < func_.size() && func_[type_index] != nullptr;
87 }
88 /*!
89 * \brief invoke the functor, dispatch on type of n
90 * \param n The Node argument
91 * \param args The additional arguments
92 * \return The result.
93 */
94 R operator()(const ObjectRef& n, Args... args) const {
95 ICHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on type "
96 << n->GetTypeKey();
97 return (*func_[n->type_index()])(n, std::forward<Args>(args)...);
98 }
99 /*!
100 * \brief set the dispatcher for type TNode
101 * \param f The function to be set.
102 * \tparam TNode the type of Node to be dispatched.
103 * \return reference to self.
104 */
105 template <typename TNode>
106 TSelf& set_dispatch(FPointer f) { // NOLINT(*)
107 uint32_t tindex = TNode::RuntimeTypeIndex();
108 if (func_.size() <= tindex) {
109 func_.resize(tindex + 1, nullptr);
110 }
111 ICHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key << " is already set";
112 func_[tindex] = f;
113 return *this;
114 }
115 /*!
116 * \brief unset the dispatcher for type TNode
117 *
118 * \tparam TNode the type of Node to be dispatched.
119 * \return reference to self.
120 */
121 template <typename TNode>
122 TSelf& clear_dispatch() { // NOLINT(*)
123 uint32_t tindex = TNode::RuntimeTypeIndex();
124 ICHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
125 func_[tindex] = nullptr;
126 return *this;
127 }
128};
129
/*!
 * \brief Begin the declaration of a static reference variable used for functor registration.
 *
 * Expands to `static TVM_ATTRIBUTE_UNUSED auto& __make_functor_<ClsName>`.
 * The expansion is only the declarator; the user of the macro supplies the
 * initializer (and typically pastes a unique suffix onto the variable name).
 *
 * \param ClsName Identifier pasted onto the end of the variable name.
 */
#define TVM_REG_FUNC_VAR_DEF(ClsName) \
  static TVM_ATTRIBUTE_UNUSED auto& __make_functor_##ClsName
131
/*!
 * \brief Helper macro to register NodeFunctor dispatch entries from a static variable.
 *
 * The macro expands to the definition of a static reference variable
 * (declared via TVM_REG_FUNC_VAR_DEF, with __COUNTER__ passed to
 * TVM_STR_CONCAT to form its name) that is initialized from
 * `ClsName::FField()`. Calls such as `.set_dispatch<T>(...)` written directly
 * after the macro become part of that initializer.
 *
 * \code
 *   // Use NodeFunctor to implement ReprPrinter in the manner of the Visitor
 *   // pattern. The vtable makes it easy to add new Node types without
 *   // changing the interface of ReprPrinter.
 *
 *   class ReprPrinter {
 *    public:
 *     std::ostream& stream;
 *     // the dispatch function.
 *     void print(Expr e) {
 *       static const FType& f = vtable();
 *       f(e, this);
 *     }
 *
 *     using FType = NodeFunctor<void (const ObjectRef&, ReprPrinter*)>;
 *     // function to return the global function table
 *     static FType& vtable();
 *   };
 *
 *   // in cpp/cc file
 *   ReprPrinter::FType& ReprPrinter::vtable() {  // NOLINT(*)
 *     static FType inst; return inst;
 *   }
 *
 *   TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 *   .set_dispatch<Add>([](const ObjectRef& ref, ReprPrinter* p) {
 *     auto* n = static_cast<const Add*>(ref.get());
 *     p->print(n->a);
 *     p->stream << '+';
 *     p->print(n->b);
 *   });
 * \endcode
 *
 * \param ClsName The name of the class that owns the functor table.
 * \param FField The static member function of ClsName that returns the
 *        NodeFunctor singleton by reference.
 */
#define TVM_STATIC_IR_FUNCTOR(ClsName, FField)                 \
  TVM_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__) = \
      ClsName::FField()
175} // namespace tvm
176#endif // TVM_NODE_FUNCTOR_H_
177