1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/runtime/registry.h |
22 | * \brief This file defines the TVM global function registry. |
23 | * |
24 | * The registered functions will be made available to front-end |
25 | * as well as backend users. |
26 | * |
27 | * The registry stores type-erased functions. |
28 | * Each registered function is automatically exposed |
 * to the front-end language (e.g. Python).
30 | * |
 * The front-end can also pass callbacks as PackedFunc, or register
 * them into the same global registry in C++.
33 | * The goal is to mix the front-end language and the TVM back-end. |
34 | * |
35 | * \code |
36 | * // register the function as MyAPIFuncName |
 * TVM_REGISTER_GLOBAL("MyAPIFuncName")
38 | * .set_body([](TVMArgs args, TVMRetValue* rv) { |
39 | * // my code. |
40 | * }); |
41 | * \endcode |
42 | */ |
43 | #ifndef TVM_RUNTIME_REGISTRY_H_ |
44 | #define TVM_RUNTIME_REGISTRY_H_ |
45 | |
46 | #include <tvm/runtime/packed_func.h> |
47 | |
48 | #include <string> |
49 | #include <type_traits> |
50 | #include <utility> |
51 | #include <vector> |
52 | |
53 | namespace tvm { |
54 | namespace runtime { |
55 | |
/*!
 * \brief Check if signals have been sent to the process and if so
 *  invoke the registered signal handler in the frontend environment.
 *
 *  When running TVM in another language (Python), the signal handler
 *  may not be immediately executed, but instead the signal is marked
 *  in the interpreter state (to ensure non-blocking of the signal handler).
 *
 *  This function can be explicitly invoked to check the cached signal
 *  and run the related processing if a signal is marked.
 *
 *  On Linux, when siginterrupt() is set, invoke this function whenever a syscall returns EINTR.
 *  When it is not set, invoke it between long-running syscalls when you will not immediately
 *  return to the frontend. On Windows, the same rules apply, but due to differences in signal
 *  processing, these are likely to only make a difference when used with Ctrl+C and socket calls.
 *
 *  Not inserting this function will not cause any correctness
 *  issue, but will delay invoking the Python-side signal handler until the function returns to
 *  the Python side. This means that the effect of e.g. pressing Ctrl+C or sending signals to the
 *  process will be delayed until the function returns. When a C function is blocked on a syscall
 *  such as accept(), this function needs to be called when EINTR is received.
 *  It is therefore not needed in most API functions, which finish quickly in a
 *  reasonable, deterministic amount of time.
 *
 * \code
 *
 *   int check_signal_every_k_iter = 10;
 *
 *   for (int iter = 0; iter < very_large_number; ++iter) {
 *     if (iter % check_signal_every_k_iter == 0) {
 *       tvm::runtime::EnvCheckSignals();
 *     }
 *     // do work here
 *   }
 *
 * \endcode
 *
 * \note This function is a no-op when no PyErr_CheckSignals is registered.
 *
 * \throws This function throws an exception when the frontend signal handler
 *         indicates that an error happened; otherwise it returns normally.
 */
TVM_DLL void EnvCheckSignals();
99 | |
100 | /*! \brief Registry for global function */ |
101 | class Registry { |
102 | public: |
103 | /*! |
104 | * \brief set the body of the function to be f |
105 | * \param f The body of the function. |
106 | */ |
107 | TVM_DLL Registry& set_body(PackedFunc f); // NOLINT(*) |
108 | /*! |
109 | * \brief set the body of the function to be f |
110 | * \param f The body of the function. |
111 | */ |
112 | template <typename TCallable, |
113 | typename = typename std::enable_if_t< |
114 | std::is_convertible<TCallable, std::function<void(TVMArgs, TVMRetValue*)>>::value && |
115 | !std::is_base_of<PackedFunc, TCallable>::value>> |
116 | Registry& set_body(TCallable f) { // NOLINT(*) |
117 | return set_body(PackedFunc(f)); |
118 | } |
119 | /*! |
120 | * \brief set the body of the function to the given function. |
121 | * Note that this will ignore default arg values and always require all arguments to be |
122 | * provided. |
123 | * |
124 | * \code |
125 | * |
126 | * int multiply(int x, int y) { |
127 | * return x * y; |
128 | * } |
129 | * |
130 | * TVM_REGISTER_GLOBAL("multiply") |
131 | * .set_body_typed(multiply); // will have type int(int, int) |
132 | * |
133 | * // will have type int(int, int) |
134 | * TVM_REGISTER_GLOBAL("sub") |
135 | * .set_body_typed([](int a, int b) -> int { return a - b; }); |
136 | * |
137 | * \endcode |
138 | * |
139 | * \param f The function to forward to. |
140 | * \tparam FLambda The signature of the function. |
141 | */ |
142 | template <typename FLambda> |
143 | Registry& set_body_typed(FLambda f) { |
144 | using FType = typename detail::function_signature<FLambda>::FType; |
145 | return set_body(TypedPackedFunc<FType>(std::move(f), name_).packed()); |
146 | } |
147 | /*! |
148 | * \brief set the body of the function to be the passed method pointer. |
149 | * Note that this will ignore default arg values and always require all arguments to be |
150 | * provided. |
151 | * |
152 | * \code |
153 | * |
154 | * // node subclass: |
155 | * struct Example { |
156 | * int doThing(int x); |
157 | * } |
158 | * TVM_REGISTER_GLOBAL("Example_doThing") |
159 | * .set_body_method(&Example::doThing); // will have type int(Example, int) |
160 | * |
161 | * \endcode |
162 | * |
163 | * \param f the method pointer to forward to. |
164 | * \tparam T the type containing the method (inferred). |
165 | * \tparam R the return type of the function (inferred). |
166 | * \tparam Args the argument types of the function (inferred). |
167 | */ |
168 | template <typename T, typename R, typename... Args> |
169 | Registry& set_body_method(R (T::*f)(Args...)) { |
170 | using R_ = typename std::remove_reference<R>::type; |
171 | auto fwrap = [f](T target, Args... params) -> R_ { |
172 | // call method pointer |
173 | return (target.*f)(params...); |
174 | }; |
175 | return set_body(TypedPackedFunc<R_(T, Args...)>(fwrap, name_)); |
176 | } |
177 | |
178 | /*! |
179 | * \brief set the body of the function to be the passed method pointer. |
180 | * Note that this will ignore default arg values and always require all arguments to be |
181 | * provided. |
182 | * |
183 | * \code |
184 | * |
185 | * // node subclass: |
186 | * struct Example { |
187 | * int doThing(int x); |
188 | * } |
189 | * TVM_REGISTER_GLOBAL("Example_doThing") |
190 | * .set_body_method(&Example::doThing); // will have type int(Example, int) |
191 | * |
192 | * \endcode |
193 | * |
194 | * \param f the method pointer to forward to. |
195 | * \tparam T the type containing the method (inferred). |
196 | * \tparam R the return type of the function (inferred). |
197 | * \tparam Args the argument types of the function (inferred). |
198 | */ |
199 | template <typename T, typename R, typename... Args> |
200 | Registry& set_body_method(R (T::*f)(Args...) const) { |
201 | auto fwrap = [f](const T target, Args... params) -> R { |
202 | // call method pointer |
203 | return (target.*f)(params...); |
204 | }; |
205 | return set_body(TypedPackedFunc<R(const T, Args...)>(fwrap, name_)); |
206 | } |
207 | |
208 | /*! |
209 | * \brief set the body of the function to be the passed method pointer. |
210 | * Used when calling a method on a Node subclass through a ObjectRef subclass. |
211 | * Note that this will ignore default arg values and always require all arguments to be |
212 | * provided. |
213 | * |
214 | * \code |
215 | * |
216 | * // node subclass: |
217 | * struct ExampleNode: BaseNode { |
218 | * int doThing(int x); |
219 | * } |
220 | * |
221 | * // noderef subclass |
222 | * struct Example; |
223 | * |
224 | * TVM_REGISTER_GLOBAL("Example_doThing") |
225 | * .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int) |
226 | * |
227 | * // note that just doing: |
228 | * // .set_body_method(&ExampleNode::doThing); |
229 | * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue. |
230 | * |
231 | * \endcode |
232 | * |
233 | * \param f the method pointer to forward to. |
234 | * \tparam TObjectRef the node reference type to call the method on |
235 | * \tparam TNode the node type containing the method (inferred). |
236 | * \tparam R the return type of the function (inferred). |
237 | * \tparam Args the argument types of the function (inferred). |
238 | */ |
239 | template <typename TObjectRef, typename TNode, typename R, typename... Args, |
240 | typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type> |
241 | Registry& set_body_method(R (TNode::*f)(Args...)) { |
242 | auto fwrap = [f](TObjectRef ref, Args... params) { |
243 | TNode* target = ref.operator->(); |
244 | // call method pointer |
245 | return (target->*f)(params...); |
246 | }; |
247 | return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap, name_)); |
248 | } |
249 | |
250 | /*! |
251 | * \brief set the body of the function to be the passed method pointer. |
252 | * Used when calling a method on a Node subclass through a ObjectRef subclass. |
253 | * Note that this will ignore default arg values and always require all arguments to be |
254 | * provided. |
255 | * |
256 | * \code |
257 | * |
258 | * // node subclass: |
259 | * struct ExampleNode: BaseNode { |
260 | * int doThing(int x); |
261 | * } |
262 | * |
263 | * // noderef subclass |
264 | * struct Example; |
265 | * |
266 | * TVM_REGISTER_GLOBAL("Example_doThing") |
267 | * .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int) |
268 | * |
269 | * // note that just doing: |
270 | * // .set_body_method(&ExampleNode::doThing); |
271 | * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue. |
272 | * |
273 | * \endcode |
274 | * |
275 | * \param f the method pointer to forward to. |
276 | * \tparam TObjectRef the node reference type to call the method on |
277 | * \tparam TNode the node type containing the method (inferred). |
278 | * \tparam R the return type of the function (inferred). |
279 | * \tparam Args the argument types of the function (inferred). |
280 | */ |
281 | template <typename TObjectRef, typename TNode, typename R, typename... Args, |
282 | typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type> |
283 | Registry& set_body_method(R (TNode::*f)(Args...) const) { |
284 | auto fwrap = [f](TObjectRef ref, Args... params) { |
285 | const TNode* target = ref.operator->(); |
286 | // call method pointer |
287 | return (target->*f)(params...); |
288 | }; |
289 | return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap, name_)); |
290 | } |
291 | |
292 | /*! |
293 | * \brief Register a function with given name |
294 | * \param name The name of the function. |
295 | * \param override Whether allow override existing function. |
296 | * \return Reference to the registry. |
297 | */ |
298 | TVM_DLL static Registry& Register(const std::string& name, bool override = false); // NOLINT(*) |
299 | /*! |
300 | * \brief Erase global function from registry, if exist. |
301 | * \param name The name of the function. |
302 | * \return Whether function exist. |
303 | */ |
304 | TVM_DLL static bool Remove(const std::string& name); |
305 | /*! |
306 | * \brief Get the global function by name. |
307 | * \param name The name of the function. |
308 | * \return pointer to the registered function, |
309 | * nullptr if it does not exist. |
310 | */ |
311 | TVM_DLL static const PackedFunc* Get(const std::string& name); // NOLINT(*) |
312 | /*! |
313 | * \brief Get the names of currently registered global function. |
314 | * \return The names |
315 | */ |
316 | TVM_DLL static std::vector<std::string> ListNames(); |
317 | |
318 | // Internal class. |
319 | struct Manager; |
320 | |
321 | protected: |
322 | /*! \brief name of the function */ |
323 | std::string name_; |
324 | /*! \brief internal packed function */ |
325 | PackedFunc func_; |
326 | friend struct Manager; |
327 | }; |
328 | |
/*!
 * \brief Declarator prefix for the variable created by TVM_REGISTER_GLOBAL.
 *
 * Expands to a `static` Registry& declaration (marked TVM_ATTRIBUTE_UNUSED) whose
 * name starts with the token pasted from `__mk_` and `TVM`. TVM_REGISTER_GLOBAL
 * passes it to TVM_STR_CONCAT together with __COUNTER__; TVM_STR_CONCAT is defined
 * elsewhere and presumably pastes the two, giving each registration in a translation
 * unit its own variable name -- confirm against its definition.
 */
#define TVM_FUNC_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_##TVM

/*!
 * \brief Register a function globally.
 *
 * Declares a static Registry& initialized with Registry::Register(OpName), using the
 * default `override = false`. OpName must be convertible to std::string. The
 * expansion ends without a semicolon so the set_body* call can be chained on it.
 *
 * \code
 *   TVM_REGISTER_GLOBAL("MyPrint")
 *   .set_body([](TVMArgs args, TVMRetValue* rv) {
 *   });
 * \endcode
 */
#define TVM_REGISTER_GLOBAL(OpName) \
  TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::runtime::Registry::Register(OpName)

/*!
 * \brief Turn a macro argument into a string literal.
 *
 * The two-level form makes TVM_STRINGIZE expand its argument first (e.g. __LINE__
 * becomes a number) before TVM_STRINGIZE_DETAIL applies the # operator.
 */
#define TVM_STRINGIZE_DETAIL(x) #x
#define TVM_STRINGIZE(x) TVM_STRINGIZE_DETAIL(x)
/*!
 * \brief Expands to a call `describe(...)` whose string-literal arguments get
 *  "\n\nFrom:<file>:<line>" of the use site appended by literal concatenation.
 *
 * The arguments must therefore be string literals. The callee `describe` is resolved
 * at the use site (presumably a member of the object being registered -- verify).
 */
#define TVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" TVM_STRINGIZE(__LINE__))
/*!
 * \brief Macro to include current line as string
 *
 * Expands to the string literal "\n\nDefined in <file>:L<line>" for the use site,
 * suitable for concatenation with adjacent string literals.
 */
#define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__)
349 | |
350 | } // namespace runtime |
351 | } // namespace tvm |
352 | #endif // TVM_RUNTIME_REGISTRY_H_ |
353 | |