1 | /* |
---|---|

2 | * Licensed to the Apache Software Foundation (ASF) under one |

3 | * or more contributor license agreements. See the NOTICE file |

4 | * distributed with this work for additional information |

5 | * regarding copyright ownership. The ASF licenses this file |

6 | * to you under the Apache License, Version 2.0 (the |

7 | * "License"); you may not use this file except in compliance |

8 | * with the License. You may obtain a copy of the License at |

9 | * |

10 | * http://www.apache.org/licenses/LICENSE-2.0 |

11 | * |

12 | * Unless required by applicable law or agreed to in writing, |

13 | * software distributed under the License is distributed on an |

14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |

15 | * KIND, either express or implied. See the License for the |

16 | * specific language governing permissions and limitations |

17 | * under the License. |

18 | */ |

19 | /*! |

20 | * \file tvm/node/functor.h |

21 | * \brief Defines the Functor data structures. |

22 | */ |

23 | #ifndef TVM_NODE_FUNCTOR_H_ |

24 | #define TVM_NODE_FUNCTOR_H_ |

25 | |

26 | #include <dmlc/logging.h> |

27 | #include <tvm/runtime/object.h> |

28 | |

29 | #include <type_traits> |

30 | #include <utility> |

31 | #include <vector> |

32 | |

33 | namespace tvm { |

34 | |

35 | using runtime::ObjectRef; |

36 | |

37 | /*! |

38 | * \brief A dynamically dispatched functor on the type of the first argument. |

39 | * |

40 | * This is a class that is useful to construct polymorphic dispatching |

41 | * base on the AST/IR node's type. |

42 | * |

43 | * \code |

44 | * NodeFunctor<std::string (const ObjectRef& n, std::string prefix)> tostr; |

45 | * tostr.set_dispatch<Add>([](const ObjectRef& op, std::string prefix) { |

46 | * return prefix + "Add"; |

47 | * }); |

48 | * tostr.set_dispatch<IntImm>([](const ObjectRef& op, std::string prefix) { |

49 | * return prefix + "IntImm" |

50 | * }); |

51 | * |

52 | * Expr x = make_const(1); |

53 | * Expr y = x + x; |

54 | * // dispatch to IntImm, outputs "MyIntImm" |

55 | * LOG(INFO) << tostr(x, "My"); |

56 | * // dispatch to IntImm, outputs "MyAdd" |

57 | * LOG(INFO) << tostr(y, "My"); |

58 | * \endcode |

59 | * |

60 | * \tparam FType function signiture |

61 | * This type if only defined for FType with function signature |

62 | */ |

63 | template <typename FType> |

64 | class NodeFunctor; |

65 | |

66 | template <typename R, typename... Args> |

67 | class NodeFunctor<R(const ObjectRef& n, Args...)> { |

68 | private: |

69 | /*! \brief internal function pointer type */ |

70 | typedef R (*FPointer)(const ObjectRef& n, Args...); |

71 | /*! \brief refer to itself. */ |

72 | using TSelf = NodeFunctor<R(const ObjectRef& n, Args...)>; |

73 | /*! \brief internal function table */ |

74 | std::vector<FPointer> func_; |

75 | |

76 | public: |

77 | /*! \brief the result type of this functor */ |

78 | using result_type = R; |

79 | /*! |

80 | * \brief Whether the functor can dispatch the corresponding Node |

81 | * \param n The node to be dispatched |

82 | * \return Whether dispatching function is registered for n's type. |

83 | */ |

84 | bool can_dispatch(const ObjectRef& n) const { |

85 | uint32_t type_index = n->type_index(); |

86 | return type_index < func_.size() && func_[type_index] != nullptr; |

87 | } |

88 | /*! |

89 | * \brief invoke the functor, dispatch on type of n |

90 | * \param n The Node argument |

91 | * \param args The additional arguments |

92 | * \return The result. |

93 | */ |

94 | R operator()(const ObjectRef& n, Args... args) const { |

95 | ICHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on type " |

96 | << n->GetTypeKey(); |

97 | return (*func_[n->type_index()])(n, std::forward<Args>(args)...); |

98 | } |

99 | /*! |

100 | * \brief set the dispatcher for type TNode |

101 | * \param f The function to be set. |

102 | * \tparam TNode the type of Node to be dispatched. |

103 | * \return reference to self. |

104 | */ |

105 | template <typename TNode> |

106 | TSelf& set_dispatch(FPointer f) { // NOLINT(*) |

107 | uint32_t tindex = TNode::RuntimeTypeIndex(); |

108 | if (func_.size() <= tindex) { |

109 | func_.resize(tindex + 1, nullptr); |

110 | } |

111 | ICHECK(func_[tindex] == nullptr) << "Dispatch for "<< TNode::_type_key << " is already set"; |

112 | func_[tindex] = f; |

113 | return *this; |

114 | } |

115 | /*! |

116 | * \brief unset the dispatcher for type TNode |

117 | * |

118 | * \tparam TNode the type of Node to be dispatched. |

119 | * \return reference to self. |

120 | */ |

121 | template <typename TNode> |

122 | TSelf& clear_dispatch() { // NOLINT(*) |

123 | uint32_t tindex = TNode::RuntimeTypeIndex(); |

124 | ICHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range"; |

125 | func_[tindex] = nullptr; |

126 | return *this; |

127 | } |

128 | }; |

129 | |

/*!
 * \brief Declare a `static` reference variable named __make_functor_<ClsName>.
 *
 * Used by TVM_STATIC_IR_FUNCTOR, which appends a __COUNTER__ suffix to the
 * variable name so that several registrations in one translation unit do not
 * collide. This relies on TVM_STR_CONCAT token-pasting its arguments after
 * they are expanded.
 */
#define TVM_REG_FUNC_VAR_DEF(ClsName) static TVM_ATTRIBUTE_UNUSED auto& __make_functor##_##ClsName

/*!
 * \brief Useful macro to set NodeFunctor dispatch in a global static field.
 *
 * Expands to the declaration of a uniquely named static reference bound to
 * ClsName::FField(), so chained calls such as `.set_dispatch<T>(...)` run
 * during that variable's initialization.
 *
 * \code
 *  // Use NodeFunctor to implement ReprPrinter similar to Visitor Pattern.
 *  // vtable allows easy patch of new Node types, without changing
 *  // interface of ReprPrinter.
 *
 *  class ReprPrinter {
 *   public:
 *    std::ostream& stream;
 *    // the dispatch function.
 *    void print(Expr e) {
 *      static const FType& f = vtable();
 *      f(e, this);
 *    }
 *
 *    using FType = NodeFunctor<void (const ObjectRef&, ReprPrinter* )>;
 *    // function to return global function table
 *    static FType& vtable();
 *  };
 *
 *  // in cpp/cc file
 *  ReprPrinter::FType& ReprPrinter::vtable() { // NOLINT(*)
 *    static FType inst; return inst;
 *  }
 *
 *  TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 *  .set_dispatch<Add>([](const ObjectRef& ref, ReprPrinter* p) {
 *    auto* n = static_cast<const Add*>(ref.get());
 *    p->print(n->a);
 *    p->stream << '+';
 *    p->print(n->b);
 *  });
 * \endcode
 *
 * \param ClsName The name of the class
 * \param FField The static function that returns a singleton of NodeFunctor.
 */
#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \
  TVM_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__) = ClsName::FField()

175 | } // namespace tvm |

176 | #endif // TVM_NODE_FUNCTOR_H_ |

177 |