1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_ |
20 | #define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_ |
21 | |
22 | #include <tvm/node/node.h> |
23 | #include <tvm/runtime/logging.h> |
24 | #include <tvm/runtime/packed_func.h> |
25 | |
26 | #include <string> |
27 | #include <type_traits> |
28 | #include <unordered_map> |
29 | #include <utility> |
30 | #include <vector> |
31 | |
32 | namespace tvm { |
33 | namespace script { |
34 | namespace printer { |
35 | |
36 | /*! |
37 | * \brief Dynamic dispatch functor based on ObjectPath. |
38 | * |
39 | * This functor dispatches based on the type of object and the input dispatch token. |
40 | */ |
41 | template <typename R, typename... Args> |
42 | class IRDocsifierFunctor { |
43 | private: |
44 | using TSelf = IRDocsifierFunctor<R, Args...>; |
45 | |
46 | template <class TObjectRef, class TCallable> |
47 | using IsDispatchFunction = |
48 | typename std::is_convertible<TCallable, std::function<R(TObjectRef, Args...)>>; |
49 | |
50 | public: |
51 | /*! |
52 | * \brief Call the dispatch function. |
53 | * \param token The dispatch token. |
54 | * \param obj The object. |
55 | * \param args Other args. |
56 | * |
57 | * \return The return value of the dispatch function |
58 | * |
59 | * If the TObjectRef isn't registered with the token, it will try to find |
60 | * dispatch function for TObjectRef with the default dispatch token (empty string). |
61 | */ |
62 | template <class TObjectRef> |
63 | R operator()(const String& token, TObjectRef obj, Args... args) const { |
64 | uint32_t type_index = obj.defined() ? obj->type_index() : 0; |
65 | const runtime::PackedFunc* pf = nullptr; |
66 | if ((pf = LookupDispatchTable(token, type_index)) != nullptr) { |
67 | return (*pf)(obj, args...); |
68 | } |
69 | if ((pf = LookupDispatchTable("" , type_index)) != nullptr) { |
70 | return (*pf)(obj, args...); |
71 | } |
72 | LOG(WARNING) << "ObjectFunctor calls un-registered function on type: " |
73 | << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")" |
74 | << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj; |
75 | ICHECK(false) << "ObjectFunctor calls un-registered function on type: " |
76 | << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")" |
77 | << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj; |
78 | } |
79 | |
80 | /*! |
81 | * \brief Set the dispatch function |
82 | * \param token The dispatch token. |
83 | * \param type_index The TVM object type index for this dispatch function. |
84 | * \param f The dispatch function. |
85 | * |
86 | * This takes a type-erased packed function as input. It should be used |
87 | * through FFI boundary, for example, registering dispatch function from Python. |
88 | */ |
89 | TSelf& set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f) { |
90 | std::vector<runtime::PackedFunc>* table = &dispatch_table_[token]; |
91 | if (table->size() <= type_index) { |
92 | table->resize(type_index + 1, nullptr); |
93 | } |
94 | runtime::PackedFunc& slot = (*table)[type_index]; |
95 | if (slot != nullptr) { |
96 | ICHECK(false) << "Dispatch for type is already registered: " |
97 | << runtime::Object::TypeIndex2Key(type_index); |
98 | } |
99 | slot = f; |
100 | return *this; |
101 | } |
102 | |
103 | /*! |
104 | * \brief Set the dispatch function |
105 | * \param token The dispatch token. |
106 | * \param f The dispatch function. |
107 | */ |
108 | template <typename TObjectRef, typename TCallable, |
109 | typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>> |
110 | TSelf& set_dispatch(String token, TCallable f) { |
111 | return set_dispatch(token, TObjectRef::ContainerType::RuntimeTypeIndex(), |
112 | runtime::TypedPackedFunc<R(TObjectRef, Args...)>(f)); |
113 | } |
114 | |
115 | /*! |
116 | * \brief Remove dispatch function |
117 | * \param token The dispatch token. |
118 | * \param type_index The TVM object type index for the dispatch function to be removed. |
119 | * |
120 | * This is useful when dispatch function comes from other language's runtime, and |
121 | * those function should be removed before that language runtime shuts down. |
122 | */ |
123 | void remove_dispatch(String token, uint32_t type_index) { |
124 | std::vector<runtime::PackedFunc>* table = &dispatch_table_[token]; |
125 | if (table->size() <= type_index) { |
126 | return; |
127 | } |
128 | (*table)[type_index] = nullptr; |
129 | } |
130 | |
131 | private: |
132 | /*! |
133 | * \brief Look up the dispatch table for the given token and type_index. |
134 | * \param token The dispatch token. |
135 | * \param type_index The TVM object type index. |
136 | * \return Returns the functor if the lookup succeeds, nullptr otherwise. |
137 | */ |
138 | const runtime::PackedFunc* LookupDispatchTable(const String& token, uint32_t type_index) const { |
139 | auto it = dispatch_table_.find(token); |
140 | if (it == dispatch_table_.end()) { |
141 | return nullptr; |
142 | } |
143 | const std::vector<runtime::PackedFunc>& tab = it->second; |
144 | if (type_index >= tab.size()) { |
145 | return nullptr; |
146 | } |
147 | const PackedFunc* f = &tab[type_index]; |
148 | if (f->defined()) { |
149 | return f; |
150 | } else { |
151 | return nullptr; |
152 | } |
153 | } |
154 | /* |
155 | * This type alias and the following free functions are created to reduce the binary bloat |
156 | * from template and also hide implementation details from this header |
157 | */ |
158 | using DispatchTable = std::unordered_map<std::string, std::vector<runtime::PackedFunc>>; |
159 | /*! \brief The dispatch table. */ |
160 | DispatchTable dispatch_table_; |
161 | }; |
162 | |
163 | } // namespace printer |
164 | } // namespace script |
165 | } // namespace tvm |
166 | #endif // TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_ |
167 | |