1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/runtime/registry.h
22 * \brief This file defines the TVM global function registry.
23 *
24 * The registered functions will be made available to front-end
25 * as well as backend users.
26 *
 * The registry stores type-erased functions.
 * Each registered function is automatically exposed
 * to front-end languages (e.g. Python).
 *
 * The front-end can also pass callbacks as PackedFunc, or register
 * them into the same global registry in C++.
 * The goal is to mix the front-end language and the TVM back-end.
34 *
35 * \code
 * // register the function as "MyAPIFuncName"
 * TVM_REGISTER_GLOBAL("MyAPIFuncName")
38 * .set_body([](TVMArgs args, TVMRetValue* rv) {
39 * // my code.
40 * });
41 * \endcode
42 */
43#ifndef TVM_RUNTIME_REGISTRY_H_
44#define TVM_RUNTIME_REGISTRY_H_
45
46#include <tvm/runtime/packed_func.h>
47
48#include <string>
49#include <type_traits>
50#include <utility>
51#include <vector>
52
53namespace tvm {
54namespace runtime {
55
/*!
 * \brief Check if signals have been sent to the process and, if so,
 *  invoke the registered signal handler in the frontend environment.
 *
 * When running TVM in another language (Python), the signal handler
 * may not be immediately executed; instead, the signal is marked
 * in the interpreter state (to ensure non-blocking of the signal handler).
 *
 * This function can be explicitly invoked to check the cached signal
 * and run the related processing if a signal is marked.
 *
 * On Linux, when siginterrupt() is set, invoke this function whenever a syscall returns EINTR.
 * When it is not set, invoke it between long-running syscalls when you will not immediately
 * return to the frontend. On Windows, the same rules apply, but due to differences in signal
 * processing, these are likely to only make a difference when used with Ctrl+C and socket calls.
 *
 * Not inserting this function will not cause any correctness
 * issue, but will delay invoking the Python-side signal handler until the function returns to
 * the Python side. This means that the effect of e.g. pressing Ctrl+C or sending signals to the
 * process will be delayed until function return. When a C function is blocked on a syscall
 * such as accept(), this function needs to be called when EINTR is received.
 * So this function is not needed in most API functions, which can finish quickly in a
 * reasonable, deterministic amount of time.
 *
 * \code
 *
 * int check_signal_every_k_iter = 10;
 *
 * for (int iter = 0; iter < very_large_number; ++iter) {
 *   if (iter % check_signal_every_k_iter == 0) {
 *     tvm::runtime::EnvCheckSignals();
 *   }
 *   // do work here
 * }
 *
 * \endcode
 *
 * \note This function is a no-op when no PyErr_CheckSignals is registered.
 *
 * \throws This function throws an exception when the frontend signal handler
 *  indicates that an error happened; otherwise it returns normally.
 */
TVM_DLL void EnvCheckSignals();
99
100/*! \brief Registry for global function */
101class Registry {
102 public:
103 /*!
104 * \brief set the body of the function to be f
105 * \param f The body of the function.
106 */
107 TVM_DLL Registry& set_body(PackedFunc f); // NOLINT(*)
108 /*!
109 * \brief set the body of the function to be f
110 * \param f The body of the function.
111 */
112 template <typename TCallable,
113 typename = typename std::enable_if_t<
114 std::is_convertible<TCallable, std::function<void(TVMArgs, TVMRetValue*)>>::value &&
115 !std::is_base_of<PackedFunc, TCallable>::value>>
116 Registry& set_body(TCallable f) { // NOLINT(*)
117 return set_body(PackedFunc(f));
118 }
119 /*!
120 * \brief set the body of the function to the given function.
121 * Note that this will ignore default arg values and always require all arguments to be
122 * provided.
123 *
124 * \code
125 *
126 * int multiply(int x, int y) {
127 * return x * y;
128 * }
129 *
130 * TVM_REGISTER_GLOBAL("multiply")
131 * .set_body_typed(multiply); // will have type int(int, int)
132 *
133 * // will have type int(int, int)
134 * TVM_REGISTER_GLOBAL("sub")
135 * .set_body_typed([](int a, int b) -> int { return a - b; });
136 *
137 * \endcode
138 *
139 * \param f The function to forward to.
140 * \tparam FLambda The signature of the function.
141 */
142 template <typename FLambda>
143 Registry& set_body_typed(FLambda f) {
144 using FType = typename detail::function_signature<FLambda>::FType;
145 return set_body(TypedPackedFunc<FType>(std::move(f), name_).packed());
146 }
147 /*!
148 * \brief set the body of the function to be the passed method pointer.
149 * Note that this will ignore default arg values and always require all arguments to be
150 * provided.
151 *
152 * \code
153 *
154 * // node subclass:
155 * struct Example {
156 * int doThing(int x);
157 * }
158 * TVM_REGISTER_GLOBAL("Example_doThing")
159 * .set_body_method(&Example::doThing); // will have type int(Example, int)
160 *
161 * \endcode
162 *
163 * \param f the method pointer to forward to.
164 * \tparam T the type containing the method (inferred).
165 * \tparam R the return type of the function (inferred).
166 * \tparam Args the argument types of the function (inferred).
167 */
168 template <typename T, typename R, typename... Args>
169 Registry& set_body_method(R (T::*f)(Args...)) {
170 using R_ = typename std::remove_reference<R>::type;
171 auto fwrap = [f](T target, Args... params) -> R_ {
172 // call method pointer
173 return (target.*f)(params...);
174 };
175 return set_body(TypedPackedFunc<R_(T, Args...)>(fwrap, name_));
176 }
177
178 /*!
179 * \brief set the body of the function to be the passed method pointer.
180 * Note that this will ignore default arg values and always require all arguments to be
181 * provided.
182 *
183 * \code
184 *
185 * // node subclass:
186 * struct Example {
187 * int doThing(int x);
188 * }
189 * TVM_REGISTER_GLOBAL("Example_doThing")
190 * .set_body_method(&Example::doThing); // will have type int(Example, int)
191 *
192 * \endcode
193 *
194 * \param f the method pointer to forward to.
195 * \tparam T the type containing the method (inferred).
196 * \tparam R the return type of the function (inferred).
197 * \tparam Args the argument types of the function (inferred).
198 */
199 template <typename T, typename R, typename... Args>
200 Registry& set_body_method(R (T::*f)(Args...) const) {
201 auto fwrap = [f](const T target, Args... params) -> R {
202 // call method pointer
203 return (target.*f)(params...);
204 };
205 return set_body(TypedPackedFunc<R(const T, Args...)>(fwrap, name_));
206 }
207
208 /*!
209 * \brief set the body of the function to be the passed method pointer.
210 * Used when calling a method on a Node subclass through a ObjectRef subclass.
211 * Note that this will ignore default arg values and always require all arguments to be
212 * provided.
213 *
214 * \code
215 *
216 * // node subclass:
217 * struct ExampleNode: BaseNode {
218 * int doThing(int x);
219 * }
220 *
221 * // noderef subclass
222 * struct Example;
223 *
224 * TVM_REGISTER_GLOBAL("Example_doThing")
225 * .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
226 *
227 * // note that just doing:
228 * // .set_body_method(&ExampleNode::doThing);
229 * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
230 *
231 * \endcode
232 *
233 * \param f the method pointer to forward to.
234 * \tparam TObjectRef the node reference type to call the method on
235 * \tparam TNode the node type containing the method (inferred).
236 * \tparam R the return type of the function (inferred).
237 * \tparam Args the argument types of the function (inferred).
238 */
239 template <typename TObjectRef, typename TNode, typename R, typename... Args,
240 typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
241 Registry& set_body_method(R (TNode::*f)(Args...)) {
242 auto fwrap = [f](TObjectRef ref, Args... params) {
243 TNode* target = ref.operator->();
244 // call method pointer
245 return (target->*f)(params...);
246 };
247 return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap, name_));
248 }
249
250 /*!
251 * \brief set the body of the function to be the passed method pointer.
252 * Used when calling a method on a Node subclass through a ObjectRef subclass.
253 * Note that this will ignore default arg values and always require all arguments to be
254 * provided.
255 *
256 * \code
257 *
258 * // node subclass:
259 * struct ExampleNode: BaseNode {
260 * int doThing(int x);
261 * }
262 *
263 * // noderef subclass
264 * struct Example;
265 *
266 * TVM_REGISTER_GLOBAL("Example_doThing")
267 * .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
268 *
269 * // note that just doing:
270 * // .set_body_method(&ExampleNode::doThing);
271 * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
272 *
273 * \endcode
274 *
275 * \param f the method pointer to forward to.
276 * \tparam TObjectRef the node reference type to call the method on
277 * \tparam TNode the node type containing the method (inferred).
278 * \tparam R the return type of the function (inferred).
279 * \tparam Args the argument types of the function (inferred).
280 */
281 template <typename TObjectRef, typename TNode, typename R, typename... Args,
282 typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
283 Registry& set_body_method(R (TNode::*f)(Args...) const) {
284 auto fwrap = [f](TObjectRef ref, Args... params) {
285 const TNode* target = ref.operator->();
286 // call method pointer
287 return (target->*f)(params...);
288 };
289 return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap, name_));
290 }
291
292 /*!
293 * \brief Register a function with given name
294 * \param name The name of the function.
295 * \param override Whether allow override existing function.
296 * \return Reference to the registry.
297 */
298 TVM_DLL static Registry& Register(const std::string& name, bool override = false); // NOLINT(*)
299 /*!
300 * \brief Erase global function from registry, if exist.
301 * \param name The name of the function.
302 * \return Whether function exist.
303 */
304 TVM_DLL static bool Remove(const std::string& name);
305 /*!
306 * \brief Get the global function by name.
307 * \param name The name of the function.
308 * \return pointer to the registered function,
309 * nullptr if it does not exist.
310 */
311 TVM_DLL static const PackedFunc* Get(const std::string& name); // NOLINT(*)
312 /*!
313 * \brief Get the names of currently registered global function.
314 * \return The names
315 */
316 TVM_DLL static std::vector<std::string> ListNames();
317
318 // Internal class.
319 struct Manager;
320
321 protected:
322 /*! \brief name of the function */
323 std::string name_;
324 /*! \brief internal packed function */
325 PackedFunc func_;
326 friend struct Manager;
327};
328
/*!
 * \brief Declaration prefix used by TVM_REGISTER_GLOBAL for its registration variable.
 *
 * Expands to `static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_TVM`
 * (the `##` pastes `__mk_` and `TVM` into a single identifier). TVM_REGISTER_GLOBAL passes
 * this macro together with __COUNTER__ to TVM_STR_CONCAT, which is defined outside this
 * file. Presumably it pastes the counter value onto the trailing identifier so that each
 * registration gets its own variable name (NOTE(review): confirm against its definition).
 *
 * NOTE(review): identifiers containing a double underscore are reserved for the
 * implementation in C++; consider a non-reserved prefix.
 */
#define TVM_FUNC_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_##TVM

/*!
 * \brief Register a function globally.
 *
 * Defines a static Registry& initialized with ::tvm::runtime::Registry::Register(OpName).
 * Any set_body* calls chained after the macro become part of that initializer expression,
 * which is why they return Registry&.
 *
 * \param OpName The name to register; it is passed to Registry::Register, which takes a
 *  const std::string&, so it is normally a string literal.
 *
 * \code
 * TVM_REGISTER_GLOBAL("MyPrint")
 * .set_body([](TVMArgs args, TVMRetValue* rv) {
 * });
 * \endcode
 */
#define TVM_REGISTER_GLOBAL(OpName) \
  TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::runtime::Registry::Register(OpName)
341
/*!
 * \brief Turn the given tokens into a string literal exactly as written (no expansion).
 *  Implementation helper for TVM_STRINGIZE.
 */
#define TVM_STRINGIZE_DETAIL(tok) #tok
/*!
 * \brief Turn the macro-expanded form of the given tokens into a string literal,
 *  so that e.g. TVM_STRINGIZE(__LINE__) yields the line number rather than "__LINE__".
 */
#define TVM_STRINGIZE(tok) TVM_STRINGIZE_DETAIL(tok)
/*!
 * \brief String literal "\n\nDefined in <file>:L<line>" naming the place where the
 *  macro is used; meant to be appended to other string literals.
 */
#define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__)
/*!
 * \brief Expands to a call `describe(...)`, appending "\n\nFrom:<file>:<line>" of the
 *  use site to the arguments. The last argument is expected to be a string literal so
 *  that it concatenates with the appended literals.
 */
#define TVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" TVM_STRINGIZE(__LINE__))
349
350} // namespace runtime
351} // namespace tvm
352#endif // TVM_RUNTIME_REGISTRY_H_
353