1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file module_util.cc
22 * \brief Utilities for module.
23 */
24#include "library_module.h"
25
26#include <dmlc/memory_io.h>
27#include <tvm/runtime/module.h>
28#include <tvm/runtime/registry.h>
29
30#include <string>
31#include <utility>
32#include <vector>
33
34namespace tvm {
35namespace runtime {
36
37// Library module that exposes symbols from a library.
38class LibraryModuleNode final : public ModuleNode {
39 public:
40 explicit LibraryModuleNode(ObjectPtr<Library> lib, PackedFuncWrapper wrapper)
41 : lib_(lib), packed_func_wrapper_(wrapper) {}
42
43 const char* type_key() const final { return "library"; }
44
45 PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
46 TVMBackendPackedCFunc faddr;
47 if (name == runtime::symbol::tvm_module_main) {
48 const char* entry_name =
49 reinterpret_cast<const char*>(lib_->GetSymbol(runtime::symbol::tvm_module_main));
50 ICHECK(entry_name != nullptr)
51 << "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
52 faddr = reinterpret_cast<TVMBackendPackedCFunc>(lib_->GetSymbol(entry_name));
53 } else {
54 faddr = reinterpret_cast<TVMBackendPackedCFunc>(lib_->GetSymbol(name.c_str()));
55 }
56 if (faddr == nullptr) return PackedFunc();
57 return packed_func_wrapper_(faddr, sptr_to_self);
58 }
59
60 private:
61 ObjectPtr<Library> lib_;
62 PackedFuncWrapper packed_func_wrapper_;
63};
64
65/*!
66 * \brief Helper classes to get into internal of a module.
67 */
68class ModuleInternal {
69 public:
70 // Get mutable reference of imports.
71 static std::vector<Module>* GetImportsAddr(ModuleNode* node) { return &(node->imports_); }
72};
73
74PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>& sptr_to_self) {
75 return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
76 TVMValue ret_value;
77 int ret_type_code = kTVMNullptr;
78 int ret = (*faddr)(const_cast<TVMValue*>(args.values), const_cast<int*>(args.type_codes),
79 args.num_args, &ret_value, &ret_type_code, nullptr);
80 ICHECK_EQ(ret, 0) << TVMGetLastError();
81 if (ret_type_code != kTVMNullptr) {
82 *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code);
83 }
84 });
85}
86
87void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) {
88#define TVM_INIT_CONTEXT_FUNC(FuncName) \
89 if (auto* fp = reinterpret_cast<decltype(&FuncName)*>(fgetsymbol("__" #FuncName))) { \
90 *fp = FuncName; \
91 }
92 // Initialize the functions
93 TVM_INIT_CONTEXT_FUNC(TVMFuncCall);
94 TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError);
95 TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv);
96 TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace);
97 TVM_INIT_CONTEXT_FUNC(TVMBackendFreeWorkspace);
98 TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch);
99 TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier);
100
101#undef TVM_INIT_CONTEXT_FUNC
102}
103
104Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) {
105 std::string loadkey = "runtime.module.loadbinary_";
106 std::string fkey = loadkey + type_key;
107 const PackedFunc* f = Registry::Get(fkey);
108 if (f == nullptr) {
109 std::string loaders = "";
110 for (auto name : Registry::ListNames()) {
111 if (name.find(loadkey, 0) == 0) {
112 if (loaders.size() > 0) {
113 loaders += ", ";
114 }
115 loaders += name.substr(loadkey.size());
116 }
117 }
118 LOG(FATAL) << "Binary was created using {" << type_key
119 << "} but a loader of that name is not registered. Available loaders are " << loaders
120 << ". Perhaps you need to recompile with this runtime enabled.";
121 }
122
123 return (*f)(static_cast<void*>(stream));
124}
125
126/*!
127 * \brief Load and append module blob to module list
128 * \param mblob The module blob.
129 * \param lib The library.
130 * \param root_module the output root module
131 * \param dso_ctx_addr the output dso module
132 */
133void ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib,
134 PackedFuncWrapper packed_func_wrapper, runtime::Module* root_module,
135 runtime::ModuleNode** dso_ctx_addr = nullptr) {
136 ICHECK(mblob != nullptr);
137 uint64_t nbytes = 0;
138 for (size_t i = 0; i < sizeof(nbytes); ++i) {
139 uint64_t c = mblob[i];
140 nbytes |= (c & 0xffUL) << (i * 8);
141 }
142 dmlc::MemoryFixedSizeStream fs(const_cast<char*>(mblob + sizeof(nbytes)),
143 static_cast<size_t>(nbytes));
144 dmlc::Stream* stream = &fs;
145 uint64_t size;
146 ICHECK(stream->Read(&size));
147 std::vector<Module> modules;
148 std::vector<uint64_t> import_tree_row_ptr;
149 std::vector<uint64_t> import_tree_child_indices;
150 int num_dso_module = 0;
151
152 for (uint64_t i = 0; i < size; ++i) {
153 std::string tkey;
154 ICHECK(stream->Read(&tkey));
155 // "_lib" serves as a placeholder in the module import tree to indicate where
156 // to place the DSOModule
157 if (tkey == "_lib") {
158 auto dso_module = Module(make_object<LibraryModuleNode>(lib, packed_func_wrapper));
159 *dso_ctx_addr = dso_module.operator->();
160 ++num_dso_module;
161 modules.emplace_back(dso_module);
162 ICHECK_EQ(num_dso_module, 1U) << "Multiple dso module detected, please upgrade tvm "
163 << " to the latest before exporting the module";
164 } else if (tkey == "_import_tree") {
165 ICHECK(stream->Read(&import_tree_row_ptr));
166 ICHECK(stream->Read(&import_tree_child_indices));
167 } else {
168 auto m = LoadModuleFromBinary(tkey, stream);
169 modules.emplace_back(m);
170 }
171 }
172
173 // if we are using old dll, we don't have import tree
174 // so that we can't reconstruct module relationship using import tree
175 if (import_tree_row_ptr.empty()) {
176 auto n = make_object<LibraryModuleNode>(lib, packed_func_wrapper);
177 auto module_import_addr = ModuleInternal::GetImportsAddr(n.operator->());
178 for (const auto& m : modules) {
179 module_import_addr->emplace_back(m);
180 }
181 *dso_ctx_addr = n.get();
182 *root_module = Module(n);
183 } else {
184 for (size_t i = 0; i < modules.size(); ++i) {
185 for (size_t j = import_tree_row_ptr[i]; j < import_tree_row_ptr[i + 1]; ++j) {
186 auto module_import_addr = ModuleInternal::GetImportsAddr(modules[i].operator->());
187 auto child_index = import_tree_child_indices[j];
188 ICHECK(child_index < modules.size());
189 module_import_addr->emplace_back(modules[child_index]);
190 }
191 }
192
193 ICHECK(!modules.empty()) << "modules cannot be empty when import tree is present";
194 // invariance: root module is always at location 0.
195 // The module order is collected via DFS
196 *root_module = modules[0];
197 }
198}
199
200Module CreateModuleFromLibrary(ObjectPtr<Library> lib, PackedFuncWrapper packed_func_wrapper) {
201 InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); });
202 auto n = make_object<LibraryModuleNode>(lib, packed_func_wrapper);
203 // Load the imported modules
204 const char* dev_mblob =
205 reinterpret_cast<const char*>(lib->GetSymbol(runtime::symbol::tvm_dev_mblob));
206
207 Module root_mod;
208 runtime::ModuleNode* dso_ctx_addr = nullptr;
209 if (dev_mblob != nullptr) {
210 ProcessModuleBlob(dev_mblob, lib, packed_func_wrapper, &root_mod, &dso_ctx_addr);
211 } else {
212 // Only have one single DSO Module
213 root_mod = Module(n);
214 dso_ctx_addr = root_mod.operator->();
215 }
216
217 // allow lookup of symbol from root (so all symbols are visible).
218 if (auto* ctx_addr = reinterpret_cast<void**>(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) {
219 *ctx_addr = dso_ctx_addr;
220 }
221
222 return root_mod;
223}
224
225TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) {
226 ObjectPtr<Library> n = CreateDSOLibraryObject(args[0]);
227 *rv = CreateModuleFromLibrary(n);
228});
229} // namespace runtime
230} // namespace tvm
231