1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
20#define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
21
#include <tvm/node/node.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/packed_func.h>

#include <functional>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
31
32namespace tvm {
33namespace script {
34namespace printer {
35
36/*!
37 * \brief Dynamic dispatch functor based on ObjectPath.
38 *
39 * This functor dispatches based on the type of object and the input dispatch token.
40 */
41template <typename R, typename... Args>
42class IRDocsifierFunctor {
43 private:
44 using TSelf = IRDocsifierFunctor<R, Args...>;
45
46 template <class TObjectRef, class TCallable>
47 using IsDispatchFunction =
48 typename std::is_convertible<TCallable, std::function<R(TObjectRef, Args...)>>;
49
50 public:
51 /*!
52 * \brief Call the dispatch function.
53 * \param token The dispatch token.
54 * \param obj The object.
55 * \param args Other args.
56 *
57 * \return The return value of the dispatch function
58 *
59 * If the TObjectRef isn't registered with the token, it will try to find
60 * dispatch function for TObjectRef with the default dispatch token (empty string).
61 */
62 template <class TObjectRef>
63 R operator()(const String& token, TObjectRef obj, Args... args) const {
64 uint32_t type_index = obj.defined() ? obj->type_index() : 0;
65 const runtime::PackedFunc* pf = nullptr;
66 if ((pf = LookupDispatchTable(token, type_index)) != nullptr) {
67 return (*pf)(obj, args...);
68 }
69 if ((pf = LookupDispatchTable("", type_index)) != nullptr) {
70 return (*pf)(obj, args...);
71 }
72 LOG(WARNING) << "ObjectFunctor calls un-registered function on type: "
73 << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")"
74 << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj;
75 ICHECK(false) << "ObjectFunctor calls un-registered function on type: "
76 << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")"
77 << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj;
78 }
79
80 /*!
81 * \brief Set the dispatch function
82 * \param token The dispatch token.
83 * \param type_index The TVM object type index for this dispatch function.
84 * \param f The dispatch function.
85 *
86 * This takes a type-erased packed function as input. It should be used
87 * through FFI boundary, for example, registering dispatch function from Python.
88 */
89 TSelf& set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f) {
90 std::vector<runtime::PackedFunc>* table = &dispatch_table_[token];
91 if (table->size() <= type_index) {
92 table->resize(type_index + 1, nullptr);
93 }
94 runtime::PackedFunc& slot = (*table)[type_index];
95 if (slot != nullptr) {
96 ICHECK(false) << "Dispatch for type is already registered: "
97 << runtime::Object::TypeIndex2Key(type_index);
98 }
99 slot = f;
100 return *this;
101 }
102
103 /*!
104 * \brief Set the dispatch function
105 * \param token The dispatch token.
106 * \param f The dispatch function.
107 */
108 template <typename TObjectRef, typename TCallable,
109 typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
110 TSelf& set_dispatch(String token, TCallable f) {
111 return set_dispatch(token, TObjectRef::ContainerType::RuntimeTypeIndex(),
112 runtime::TypedPackedFunc<R(TObjectRef, Args...)>(f));
113 }
114
115 /*!
116 * \brief Remove dispatch function
117 * \param token The dispatch token.
118 * \param type_index The TVM object type index for the dispatch function to be removed.
119 *
120 * This is useful when dispatch function comes from other language's runtime, and
121 * those function should be removed before that language runtime shuts down.
122 */
123 void remove_dispatch(String token, uint32_t type_index) {
124 std::vector<runtime::PackedFunc>* table = &dispatch_table_[token];
125 if (table->size() <= type_index) {
126 return;
127 }
128 (*table)[type_index] = nullptr;
129 }
130
131 private:
132 /*!
133 * \brief Look up the dispatch table for the given token and type_index.
134 * \param token The dispatch token.
135 * \param type_index The TVM object type index.
136 * \return Returns the functor if the lookup succeeds, nullptr otherwise.
137 */
138 const runtime::PackedFunc* LookupDispatchTable(const String& token, uint32_t type_index) const {
139 auto it = dispatch_table_.find(token);
140 if (it == dispatch_table_.end()) {
141 return nullptr;
142 }
143 const std::vector<runtime::PackedFunc>& tab = it->second;
144 if (type_index >= tab.size()) {
145 return nullptr;
146 }
147 const PackedFunc* f = &tab[type_index];
148 if (f->defined()) {
149 return f;
150 } else {
151 return nullptr;
152 }
153 }
154 /*
155 * This type alias and the following free functions are created to reduce the binary bloat
156 * from template and also hide implementation details from this header
157 */
158 using DispatchTable = std::unordered_map<std::string, std::vector<runtime::PackedFunc>>;
159 /*! \brief The dispatch table. */
160 DispatchTable dispatch_table_;
161};
162
163} // namespace printer
164} // namespace script
165} // namespace tvm
166#endif // TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
167