1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file registry.cc
 * \brief The global registry of packed functions.
 */
#include <dmlc/thread_local.h>
#include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>

#include <array>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>

#include "runtime_base.h"
35
36namespace tvm {
37namespace runtime {
38
39struct Registry::Manager {
40 // map storing the functions.
41 // We deliberately used raw pointer.
42 // This is because PackedFunc can contain callbacks into the host language (Python) and the
43 // resource can become invalid because of indeterministic order of destruction and forking.
44 // The resources will only be recycled during program exit.
45 std::unordered_map<std::string, Registry*> fmap;
46 // mutex
47 std::mutex mutex;
48
49 Manager() {}
50
51 static Manager* Global() {
52 // We deliberately leak the Manager instance, to avoid leak sanitizers
53 // complaining about the entries in Manager::fmap being leaked at program
54 // exit.
55 static Manager* inst = new Manager();
56 return inst;
57 }
58};
59
60Registry& Registry::set_body(PackedFunc f) { // NOLINT(*)
61 func_ = f;
62 return *this;
63}
64
65Registry& Registry::Register(const std::string& name, bool can_override) { // NOLINT(*)
66 Manager* m = Manager::Global();
67 std::lock_guard<std::mutex> lock(m->mutex);
68 if (m->fmap.count(name)) {
69 ICHECK(can_override) << "Global PackedFunc " << name << " is already registered";
70 }
71
72 Registry* r = new Registry();
73 r->name_ = name;
74 m->fmap[name] = r;
75 return *r;
76}
77
78bool Registry::Remove(const std::string& name) {
79 Manager* m = Manager::Global();
80 std::lock_guard<std::mutex> lock(m->mutex);
81 auto it = m->fmap.find(name);
82 if (it == m->fmap.end()) return false;
83 m->fmap.erase(it);
84 return true;
85}
86
87const PackedFunc* Registry::Get(const std::string& name) {
88 Manager* m = Manager::Global();
89 std::lock_guard<std::mutex> lock(m->mutex);
90 auto it = m->fmap.find(name);
91 if (it == m->fmap.end()) return nullptr;
92 return &(it->second->func_);
93}
94
95std::vector<std::string> Registry::ListNames() {
96 Manager* m = Manager::Global();
97 std::lock_guard<std::mutex> lock(m->mutex);
98 std::vector<std::string> keys;
99 keys.reserve(m->fmap.size());
100 for (const auto& kv : m->fmap) {
101 keys.push_back(kv.first);
102 }
103 return keys;
104}
105
106/*!
107 * \brief Execution environment specific API registry.
108 *
109 * This registry stores C API function pointers about
110 * execution environment(e.g. python) specific API function that
111 * we need for specific low-level handling(e.g. signal checking).
112 *
113 * We only stores the C API function when absolutely necessary (e.g. when signal handler
114 * cannot trap back into python). Always consider use the PackedFunc FFI when possible
115 * in other cases.
116 */
117class EnvCAPIRegistry {
118 public:
119 /*!
120 * \brief Callback to check if signals have been sent to the process and
121 * if so invoke the registered signal handler in the frontend environment.
122 *
123 * When running TVM in another language (Python), the signal handler
124 * may not be immediately executed, but instead the signal is marked
125 * in the interpreter state (to ensure non-blocking of the signal handler).
126 *
127 * \return 0 if no error happens, -1 if error happens.
128 */
129 typedef int (*F_PyErr_CheckSignals)();
130
131 // NOTE: the following function are only registered
132 // in a python environment.
133 /*!
134 * \brief PyErr_CheckSignal function
135 */
136 F_PyErr_CheckSignals pyerr_check_signals = nullptr;
137
138 static EnvCAPIRegistry* Global() {
139 static EnvCAPIRegistry* inst = new EnvCAPIRegistry();
140 return inst;
141 }
142
143 // register environment(e.g. python) specific api functions
144 void Register(const std::string& symbol_name, void* fptr) {
145 if (symbol_name == "PyErr_CheckSignals") {
146 Update(symbol_name, &pyerr_check_signals, fptr);
147 } else {
148 LOG(FATAL) << "Unknown env API " << symbol_name;
149 }
150 }
151
152 // implementation of tvm::runtime::EnvCheckSignals
153 void CheckSignals() {
154 // check python signal to see if there are exception raised
155 if (pyerr_check_signals != nullptr && (*pyerr_check_signals)() != 0) {
156 // The error will let FFI know that the frontend environment
157 // already set an error.
158 throw EnvErrorAlreadySet("");
159 }
160 }
161
162 private:
163 // update the internal API table
164 template <typename FType>
165 void Update(const std::string& symbol_name, FType* target, void* ptr) {
166 FType ptr_casted = reinterpret_cast<FType>(ptr);
167 if (target[0] != nullptr && target[0] != ptr_casted) {
168 LOG(WARNING) << "tvm.runtime.RegisterEnvCAPI overrides an existing function " << symbol_name;
169 }
170 target[0] = ptr_casted;
171 }
172};
173
174void EnvCheckSignals() { EnvCAPIRegistry::Global()->CheckSignals(); }
175
176} // namespace runtime
177} // namespace tvm
178
/*! \brief Holder for values that are handed back through the C API. */
struct TVMFuncThreadLocalEntry {
  /*! \brief Owns the strings being returned. */
  std::vector<std::string> ret_vec_str{};
  /*! \brief C-string pointers being returned (filled from ret_vec_str). */
  std::vector<const char*> ret_vec_charp{};
};
186
187/*! \brief Thread local store that can be used to hold return values. */
188typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore;
189
190int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {
191 API_BEGIN();
192 using tvm::runtime::GetRef;
193 using tvm::runtime::PackedFunc;
194 using tvm::runtime::PackedFuncObj;
195 tvm::runtime::Registry::Register(name, override != 0)
196 .set_body(GetRef<PackedFunc>(static_cast<PackedFuncObj*>(f)));
197 API_END();
198}
199
200int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
201 API_BEGIN();
202 const tvm::runtime::PackedFunc* fp = tvm::runtime::Registry::Get(name);
203 if (fp != nullptr) {
204 tvm::runtime::TVMRetValue ret;
205 ret = *fp;
206 TVMValue val;
207 int type_code;
208 ret.MoveToCHost(&val, &type_code);
209 *out = val.v_handle;
210 } else {
211 *out = nullptr;
212 }
213 API_END();
214}
215
216int TVMFuncListGlobalNames(int* out_size, const char*** out_array) {
217 API_BEGIN();
218 TVMFuncThreadLocalEntry* ret = TVMFuncThreadLocalStore::Get();
219 ret->ret_vec_str = tvm::runtime::Registry::ListNames();
220 ret->ret_vec_charp.clear();
221 for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
222 ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
223 }
224 *out_array = dmlc::BeginPtr(ret->ret_vec_charp);
225 *out_size = static_cast<int>(ret->ret_vec_str.size());
226 API_END();
227}
228
229int TVMFuncRemoveGlobal(const char* name) {
230 API_BEGIN();
231 tvm::runtime::Registry::Remove(name);
232 API_END();
233}
234
235int TVMBackendRegisterEnvCAPI(const char* name, void* ptr) {
236 API_BEGIN();
237 tvm::runtime::EnvCAPIRegistry::Global()->Register(name, ptr);
238 API_END();
239}
240