1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file registry.cc
 * \brief The global registry of packed functions.
 */
#include <dmlc/thread_local.h>
#include <tvm/runtime/c_backend_api.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>

#include <array>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>

#include "runtime_base.h"
35 | |
36 | namespace tvm { |
37 | namespace runtime { |
38 | |
39 | struct Registry::Manager { |
40 | // map storing the functions. |
41 | // We deliberately used raw pointer. |
42 | // This is because PackedFunc can contain callbacks into the host language (Python) and the |
43 | // resource can become invalid because of indeterministic order of destruction and forking. |
44 | // The resources will only be recycled during program exit. |
45 | std::unordered_map<std::string, Registry*> fmap; |
46 | // mutex |
47 | std::mutex mutex; |
48 | |
49 | Manager() {} |
50 | |
51 | static Manager* Global() { |
52 | // We deliberately leak the Manager instance, to avoid leak sanitizers |
53 | // complaining about the entries in Manager::fmap being leaked at program |
54 | // exit. |
55 | static Manager* inst = new Manager(); |
56 | return inst; |
57 | } |
58 | }; |
59 | |
60 | Registry& Registry::set_body(PackedFunc f) { // NOLINT(*) |
61 | func_ = f; |
62 | return *this; |
63 | } |
64 | |
65 | Registry& Registry::Register(const std::string& name, bool can_override) { // NOLINT(*) |
66 | Manager* m = Manager::Global(); |
67 | std::lock_guard<std::mutex> lock(m->mutex); |
68 | if (m->fmap.count(name)) { |
69 | ICHECK(can_override) << "Global PackedFunc " << name << " is already registered" ; |
70 | } |
71 | |
72 | Registry* r = new Registry(); |
73 | r->name_ = name; |
74 | m->fmap[name] = r; |
75 | return *r; |
76 | } |
77 | |
78 | bool Registry::Remove(const std::string& name) { |
79 | Manager* m = Manager::Global(); |
80 | std::lock_guard<std::mutex> lock(m->mutex); |
81 | auto it = m->fmap.find(name); |
82 | if (it == m->fmap.end()) return false; |
83 | m->fmap.erase(it); |
84 | return true; |
85 | } |
86 | |
87 | const PackedFunc* Registry::Get(const std::string& name) { |
88 | Manager* m = Manager::Global(); |
89 | std::lock_guard<std::mutex> lock(m->mutex); |
90 | auto it = m->fmap.find(name); |
91 | if (it == m->fmap.end()) return nullptr; |
92 | return &(it->second->func_); |
93 | } |
94 | |
95 | std::vector<std::string> Registry::ListNames() { |
96 | Manager* m = Manager::Global(); |
97 | std::lock_guard<std::mutex> lock(m->mutex); |
98 | std::vector<std::string> keys; |
99 | keys.reserve(m->fmap.size()); |
100 | for (const auto& kv : m->fmap) { |
101 | keys.push_back(kv.first); |
102 | } |
103 | return keys; |
104 | } |
105 | |
106 | /*! |
107 | * \brief Execution environment specific API registry. |
108 | * |
109 | * This registry stores C API function pointers about |
110 | * execution environment(e.g. python) specific API function that |
111 | * we need for specific low-level handling(e.g. signal checking). |
112 | * |
113 | * We only stores the C API function when absolutely necessary (e.g. when signal handler |
114 | * cannot trap back into python). Always consider use the PackedFunc FFI when possible |
115 | * in other cases. |
116 | */ |
117 | class EnvCAPIRegistry { |
118 | public: |
119 | /*! |
120 | * \brief Callback to check if signals have been sent to the process and |
121 | * if so invoke the registered signal handler in the frontend environment. |
122 | * |
123 | * When running TVM in another language (Python), the signal handler |
124 | * may not be immediately executed, but instead the signal is marked |
125 | * in the interpreter state (to ensure non-blocking of the signal handler). |
126 | * |
127 | * \return 0 if no error happens, -1 if error happens. |
128 | */ |
129 | typedef int (*F_PyErr_CheckSignals)(); |
130 | |
131 | // NOTE: the following function are only registered |
132 | // in a python environment. |
133 | /*! |
134 | * \brief PyErr_CheckSignal function |
135 | */ |
136 | F_PyErr_CheckSignals pyerr_check_signals = nullptr; |
137 | |
138 | static EnvCAPIRegistry* Global() { |
139 | static EnvCAPIRegistry* inst = new EnvCAPIRegistry(); |
140 | return inst; |
141 | } |
142 | |
143 | // register environment(e.g. python) specific api functions |
144 | void Register(const std::string& symbol_name, void* fptr) { |
145 | if (symbol_name == "PyErr_CheckSignals" ) { |
146 | Update(symbol_name, &pyerr_check_signals, fptr); |
147 | } else { |
148 | LOG(FATAL) << "Unknown env API " << symbol_name; |
149 | } |
150 | } |
151 | |
152 | // implementation of tvm::runtime::EnvCheckSignals |
153 | void CheckSignals() { |
154 | // check python signal to see if there are exception raised |
155 | if (pyerr_check_signals != nullptr && (*pyerr_check_signals)() != 0) { |
156 | // The error will let FFI know that the frontend environment |
157 | // already set an error. |
158 | throw EnvErrorAlreadySet("" ); |
159 | } |
160 | } |
161 | |
162 | private: |
163 | // update the internal API table |
164 | template <typename FType> |
165 | void Update(const std::string& symbol_name, FType* target, void* ptr) { |
166 | FType ptr_casted = reinterpret_cast<FType>(ptr); |
167 | if (target[0] != nullptr && target[0] != ptr_casted) { |
168 | LOG(WARNING) << "tvm.runtime.RegisterEnvCAPI overrides an existing function " << symbol_name; |
169 | } |
170 | target[0] = ptr_casted; |
171 | } |
172 | }; |
173 | |
174 | void EnvCheckSignals() { EnvCAPIRegistry::Global()->CheckSignals(); } |
175 | |
176 | } // namespace runtime |
177 | } // namespace tvm |
178 | |
/*! \brief Scratch storage that keeps C API return values alive after the call returns. */
struct TVMFuncThreadLocalEntry {
  /*! \brief Owns the strings whose character pointers are handed back to the caller. */
  std::vector<std::string> ret_vec_str{};
  /*! \brief Raw C-string pointers returned to the caller. */
  std::vector<const char*> ret_vec_charp{};
};
186 | |
187 | /*! \brief Thread local store that can be used to hold return values. */ |
188 | typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore; |
189 | |
190 | int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) { |
191 | API_BEGIN(); |
192 | using tvm::runtime::GetRef; |
193 | using tvm::runtime::PackedFunc; |
194 | using tvm::runtime::PackedFuncObj; |
195 | tvm::runtime::Registry::Register(name, override != 0) |
196 | .set_body(GetRef<PackedFunc>(static_cast<PackedFuncObj*>(f))); |
197 | API_END(); |
198 | } |
199 | |
200 | int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { |
201 | API_BEGIN(); |
202 | const tvm::runtime::PackedFunc* fp = tvm::runtime::Registry::Get(name); |
203 | if (fp != nullptr) { |
204 | tvm::runtime::TVMRetValue ret; |
205 | ret = *fp; |
206 | TVMValue val; |
207 | int type_code; |
208 | ret.MoveToCHost(&val, &type_code); |
209 | *out = val.v_handle; |
210 | } else { |
211 | *out = nullptr; |
212 | } |
213 | API_END(); |
214 | } |
215 | |
216 | int TVMFuncListGlobalNames(int* out_size, const char*** out_array) { |
217 | API_BEGIN(); |
218 | TVMFuncThreadLocalEntry* ret = TVMFuncThreadLocalStore::Get(); |
219 | ret->ret_vec_str = tvm::runtime::Registry::ListNames(); |
220 | ret->ret_vec_charp.clear(); |
221 | for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { |
222 | ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); |
223 | } |
224 | *out_array = dmlc::BeginPtr(ret->ret_vec_charp); |
225 | *out_size = static_cast<int>(ret->ret_vec_str.size()); |
226 | API_END(); |
227 | } |
228 | |
229 | int TVMFuncRemoveGlobal(const char* name) { |
230 | API_BEGIN(); |
231 | tvm::runtime::Registry::Remove(name); |
232 | API_END(); |
233 | } |
234 | |
235 | int TVMBackendRegisterEnvCAPI(const char* name, void* ptr) { |
236 | API_BEGIN(); |
237 | tvm::runtime::EnvCAPIRegistry::Global()->Register(name, ptr); |
238 | API_END(); |
239 | } |
240 | |