1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file module_util.cc |
22 | * \brief Utilities for module. |
23 | */ |
24 | #include "library_module.h" |
25 | |
26 | #include <dmlc/memory_io.h> |
27 | #include <tvm/runtime/module.h> |
28 | #include <tvm/runtime/registry.h> |
29 | |
30 | #include <string> |
31 | #include <utility> |
32 | #include <vector> |
33 | |
34 | namespace tvm { |
35 | namespace runtime { |
36 | |
37 | // Library module that exposes symbols from a library. |
38 | class LibraryModuleNode final : public ModuleNode { |
39 | public: |
40 | explicit LibraryModuleNode(ObjectPtr<Library> lib, PackedFuncWrapper wrapper) |
41 | : lib_(lib), packed_func_wrapper_(wrapper) {} |
42 | |
43 | const char* type_key() const final { return "library" ; } |
44 | |
45 | PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final { |
46 | TVMBackendPackedCFunc faddr; |
47 | if (name == runtime::symbol::tvm_module_main) { |
48 | const char* entry_name = |
49 | reinterpret_cast<const char*>(lib_->GetSymbol(runtime::symbol::tvm_module_main)); |
50 | ICHECK(entry_name != nullptr) |
51 | << "Symbol " << runtime::symbol::tvm_module_main << " is not presented" ; |
52 | faddr = reinterpret_cast<TVMBackendPackedCFunc>(lib_->GetSymbol(entry_name)); |
53 | } else { |
54 | faddr = reinterpret_cast<TVMBackendPackedCFunc>(lib_->GetSymbol(name.c_str())); |
55 | } |
56 | if (faddr == nullptr) return PackedFunc(); |
57 | return packed_func_wrapper_(faddr, sptr_to_self); |
58 | } |
59 | |
60 | private: |
61 | ObjectPtr<Library> lib_; |
62 | PackedFuncWrapper packed_func_wrapper_; |
63 | }; |
64 | |
65 | /*! |
66 | * \brief Helper classes to get into internal of a module. |
67 | */ |
68 | class ModuleInternal { |
69 | public: |
70 | // Get mutable reference of imports. |
71 | static std::vector<Module>* GetImportsAddr(ModuleNode* node) { return &(node->imports_); } |
72 | }; |
73 | |
74 | PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>& sptr_to_self) { |
75 | return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { |
76 | TVMValue ret_value; |
77 | int ret_type_code = kTVMNullptr; |
78 | int ret = (*faddr)(const_cast<TVMValue*>(args.values), const_cast<int*>(args.type_codes), |
79 | args.num_args, &ret_value, &ret_type_code, nullptr); |
80 | ICHECK_EQ(ret, 0) << TVMGetLastError(); |
81 | if (ret_type_code != kTVMNullptr) { |
82 | *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code); |
83 | } |
84 | }); |
85 | } |
86 | |
87 | void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) { |
88 | #define TVM_INIT_CONTEXT_FUNC(FuncName) \ |
89 | if (auto* fp = reinterpret_cast<decltype(&FuncName)*>(fgetsymbol("__" #FuncName))) { \ |
90 | *fp = FuncName; \ |
91 | } |
92 | // Initialize the functions |
93 | TVM_INIT_CONTEXT_FUNC(TVMFuncCall); |
94 | TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError); |
95 | TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv); |
96 | TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace); |
97 | TVM_INIT_CONTEXT_FUNC(TVMBackendFreeWorkspace); |
98 | TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch); |
99 | TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier); |
100 | |
101 | #undef TVM_INIT_CONTEXT_FUNC |
102 | } |
103 | |
104 | Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) { |
105 | std::string loadkey = "runtime.module.loadbinary_" ; |
106 | std::string fkey = loadkey + type_key; |
107 | const PackedFunc* f = Registry::Get(fkey); |
108 | if (f == nullptr) { |
109 | std::string loaders = "" ; |
110 | for (auto name : Registry::ListNames()) { |
111 | if (name.find(loadkey, 0) == 0) { |
112 | if (loaders.size() > 0) { |
113 | loaders += ", " ; |
114 | } |
115 | loaders += name.substr(loadkey.size()); |
116 | } |
117 | } |
118 | LOG(FATAL) << "Binary was created using {" << type_key |
119 | << "} but a loader of that name is not registered. Available loaders are " << loaders |
120 | << ". Perhaps you need to recompile with this runtime enabled." ; |
121 | } |
122 | |
123 | return (*f)(static_cast<void*>(stream)); |
124 | } |
125 | |
126 | /*! |
127 | * \brief Load and append module blob to module list |
128 | * \param mblob The module blob. |
129 | * \param lib The library. |
130 | * \param root_module the output root module |
131 | * \param dso_ctx_addr the output dso module |
132 | */ |
133 | void ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib, |
134 | PackedFuncWrapper packed_func_wrapper, runtime::Module* root_module, |
135 | runtime::ModuleNode** dso_ctx_addr = nullptr) { |
136 | ICHECK(mblob != nullptr); |
137 | uint64_t nbytes = 0; |
138 | for (size_t i = 0; i < sizeof(nbytes); ++i) { |
139 | uint64_t c = mblob[i]; |
140 | nbytes |= (c & 0xffUL) << (i * 8); |
141 | } |
142 | dmlc::MemoryFixedSizeStream fs(const_cast<char*>(mblob + sizeof(nbytes)), |
143 | static_cast<size_t>(nbytes)); |
144 | dmlc::Stream* stream = &fs; |
145 | uint64_t size; |
146 | ICHECK(stream->Read(&size)); |
147 | std::vector<Module> modules; |
148 | std::vector<uint64_t> import_tree_row_ptr; |
149 | std::vector<uint64_t> import_tree_child_indices; |
150 | int num_dso_module = 0; |
151 | |
152 | for (uint64_t i = 0; i < size; ++i) { |
153 | std::string tkey; |
154 | ICHECK(stream->Read(&tkey)); |
155 | // "_lib" serves as a placeholder in the module import tree to indicate where |
156 | // to place the DSOModule |
157 | if (tkey == "_lib" ) { |
158 | auto dso_module = Module(make_object<LibraryModuleNode>(lib, packed_func_wrapper)); |
159 | *dso_ctx_addr = dso_module.operator->(); |
160 | ++num_dso_module; |
161 | modules.emplace_back(dso_module); |
162 | ICHECK_EQ(num_dso_module, 1U) << "Multiple dso module detected, please upgrade tvm " |
163 | << " to the latest before exporting the module" ; |
164 | } else if (tkey == "_import_tree" ) { |
165 | ICHECK(stream->Read(&import_tree_row_ptr)); |
166 | ICHECK(stream->Read(&import_tree_child_indices)); |
167 | } else { |
168 | auto m = LoadModuleFromBinary(tkey, stream); |
169 | modules.emplace_back(m); |
170 | } |
171 | } |
172 | |
173 | // if we are using old dll, we don't have import tree |
174 | // so that we can't reconstruct module relationship using import tree |
175 | if (import_tree_row_ptr.empty()) { |
176 | auto n = make_object<LibraryModuleNode>(lib, packed_func_wrapper); |
177 | auto module_import_addr = ModuleInternal::GetImportsAddr(n.operator->()); |
178 | for (const auto& m : modules) { |
179 | module_import_addr->emplace_back(m); |
180 | } |
181 | *dso_ctx_addr = n.get(); |
182 | *root_module = Module(n); |
183 | } else { |
184 | for (size_t i = 0; i < modules.size(); ++i) { |
185 | for (size_t j = import_tree_row_ptr[i]; j < import_tree_row_ptr[i + 1]; ++j) { |
186 | auto module_import_addr = ModuleInternal::GetImportsAddr(modules[i].operator->()); |
187 | auto child_index = import_tree_child_indices[j]; |
188 | ICHECK(child_index < modules.size()); |
189 | module_import_addr->emplace_back(modules[child_index]); |
190 | } |
191 | } |
192 | |
193 | ICHECK(!modules.empty()) << "modules cannot be empty when import tree is present" ; |
194 | // invariance: root module is always at location 0. |
195 | // The module order is collected via DFS |
196 | *root_module = modules[0]; |
197 | } |
198 | } |
199 | |
200 | Module CreateModuleFromLibrary(ObjectPtr<Library> lib, PackedFuncWrapper packed_func_wrapper) { |
201 | InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); }); |
202 | auto n = make_object<LibraryModuleNode>(lib, packed_func_wrapper); |
203 | // Load the imported modules |
204 | const char* dev_mblob = |
205 | reinterpret_cast<const char*>(lib->GetSymbol(runtime::symbol::tvm_dev_mblob)); |
206 | |
207 | Module root_mod; |
208 | runtime::ModuleNode* dso_ctx_addr = nullptr; |
209 | if (dev_mblob != nullptr) { |
210 | ProcessModuleBlob(dev_mblob, lib, packed_func_wrapper, &root_mod, &dso_ctx_addr); |
211 | } else { |
212 | // Only have one single DSO Module |
213 | root_mod = Module(n); |
214 | dso_ctx_addr = root_mod.operator->(); |
215 | } |
216 | |
217 | // allow lookup of symbol from root (so all symbols are visible). |
218 | if (auto* ctx_addr = reinterpret_cast<void**>(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) { |
219 | *ctx_addr = dso_ctx_addr; |
220 | } |
221 | |
222 | return root_mod; |
223 | } |
224 | |
225 | TVM_REGISTER_GLOBAL("runtime.module.loadfile_so" ).set_body([](TVMArgs args, TVMRetValue* rv) { |
226 | ObjectPtr<Library> n = CreateDSOLibraryObject(args[0]); |
227 | *rv = CreateModuleFromLibrary(n); |
228 | }); |
229 | } // namespace runtime |
230 | } // namespace tvm |
231 | |