1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file src/runtime/const_loader_module.cc
 * \brief A wrapper for initializing imported modules using constant NDArrays. This
 * module is intended to be used by various runtimes in the TVM stack, e.g. the
 * graph executor, relay VM, AOT runtime, and various user-defined runtimes. It
 * paves the way to separate the code and metadata, which makes compilation
 * and/or interpretation more convenient. In addition, the clear separation of
 * code and constants significantly reduces the effort of handling external
 * codegen and runtimes.
 */
30 | #include <tvm/runtime/container/array.h> |
31 | #include <tvm/runtime/container/string.h> |
32 | #include <tvm/runtime/ndarray.h> |
33 | #include <tvm/runtime/packed_func.h> |
34 | #include <tvm/runtime/registry.h> |
35 | |
36 | #include <cstdint> |
37 | #include <sstream> |
38 | |
39 | #include "meta_data.h" |
40 | |
41 | namespace tvm { |
42 | namespace runtime { |
43 | |
44 | /*! |
45 | * \brief The const-loader module is designed to manage initialization of the |
46 | * imported submodules for the C++ runtime. |
47 | */ |
48 | class ConstLoaderModuleNode : public ModuleNode { |
49 | public: |
50 | ConstLoaderModuleNode( |
51 | const std::unordered_map<std::string, NDArray>& const_var_ndarray, |
52 | const std::unordered_map<std::string, std::vector<std::string>>& const_vars_by_symbol) |
53 | : const_var_ndarray_(const_var_ndarray), const_vars_by_symbol_(const_vars_by_symbol) { |
54 | VLOG(1) << "Creating ConstLoaderModule" ; |
55 | // Only the related submodules are cached to reduce the number of runtime |
56 | // symbol lookup for initialization. Otherwise, symbols/primitives in the |
57 | // DSO module will also be cached but they never need to be initialized. |
58 | for (const auto& kv : const_vars_by_symbol_) { |
59 | for (const auto& var : kv.second) { |
60 | VLOG(1) << "ConstLoaderModuleNode has constant '" << var << "' for function '" << kv.first |
61 | << "'" ; |
62 | ICHECK_GT(const_var_ndarray_.count(var), 0) |
63 | << "ConstLoaderModuleNode is missing entry for constant '" << var << "' for function '" |
64 | << kv.first << "'" ; |
65 | } |
66 | initialized_[kv.first] = false; |
67 | } |
68 | } |
69 | |
70 | PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final { |
71 | VLOG(1) << "ConstLoaderModuleNode::GetFunction(" << name << ")" ; |
72 | // Initialize and memoize the module. |
73 | // Usually, we have some warmup runs. The module initialization should be |
74 | // done at this stage. Therefore, runtime overhead is not a concern. |
75 | if (initialized_.count(name) && !initialized_.at(name)) { |
76 | this->InitSubModule(name); |
77 | initialized_[name] = true; |
78 | } |
79 | |
80 | if (name == "get_const_var_ndarray" ) { |
81 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
82 | Map<String, ObjectRef> ret_map; |
83 | for (const auto& kv : const_var_ndarray_) { |
84 | ret_map.Set(kv.first, kv.second); |
85 | } |
86 | *rv = ret_map; |
87 | }); |
88 | } |
89 | |
90 | // Run the module. |
91 | // Normally we would only have a limited number of submodules. The runtime |
92 | // symobl lookup overhead should be minimal. |
93 | ICHECK(!this->imports().empty()); |
94 | for (Module it : this->imports()) { |
95 | PackedFunc pf = it.GetFunction(name); |
96 | if (pf != nullptr) return pf; |
97 | } |
98 | return PackedFunc(nullptr); |
99 | } |
100 | |
101 | const char* type_key() const final { return "const_loader" ; } |
102 | |
103 | /*! |
104 | * \brief Get the list of constants that is required by the given module. |
105 | * \param symbol The symbol that is being queried. |
106 | * \return The list of needed NDArray. |
107 | */ |
108 | Array<NDArray> GetRequiredConstants(const std::string& symbol) { |
109 | Array<NDArray> ret; |
110 | ICHECK_GT(const_vars_by_symbol_.count(symbol), 0U) |
111 | << "No constants known for function '" << symbol << "'" ; |
112 | std::vector<std::string> vars = const_vars_by_symbol_[symbol]; |
113 | for (const auto& var : vars) { |
114 | ICHECK_GT(const_var_ndarray_.count(var), 0U) |
115 | << "No such constant variable '" << var << "' for function '" << symbol << "'" ; |
116 | ret.push_back(const_var_ndarray_[var]); |
117 | } |
118 | return ret; |
119 | } |
120 | |
121 | /*! |
122 | * \brief Initialize each imported module. |
123 | * \param symobl The symbol used for initializing a module. It is also used |
124 | * for runtime lookup. |
125 | * |
126 | * \note A module could be like the following: |
127 | * ConstLoaderModuleNode (contains all the constants) |
128 | * - CSourceModule |
129 | * - JSON runtime module |
130 | * |
131 | * The initializer iterates through the imported modules and intilizes the |
132 | * found module accordingly by passing the needed constants into it. |
133 | */ |
134 | void InitSubModule(const std::string& symbol) { |
135 | PackedFunc init(nullptr); |
136 | for (Module it : this->imports()) { |
137 | // Get the initialization function from the imported modules. |
138 | std::string init_name = "__init_" + symbol; |
139 | init = it.GetFunction(init_name, false); |
140 | if (init != nullptr) { |
141 | auto md = GetRequiredConstants(symbol); |
142 | // Initialize the module with constants. |
143 | int ret = init(md); |
144 | // Report the error if initialization is failed. |
145 | ICHECK_EQ(ret, 0) << TVMGetLastError(); |
146 | break; |
147 | } |
148 | } |
149 | } |
150 | |
151 | void SaveToBinary(dmlc::Stream* stream) final { |
152 | std::vector<std::string> variables; |
153 | std::vector<NDArray> const_var_ndarray; |
154 | for (const auto& it : const_var_ndarray_) { |
155 | String var_name = it.first; |
156 | variables.push_back(var_name); |
157 | const_var_ndarray.push_back(it.second); |
158 | } |
159 | |
160 | // Save all variables in the function. |
161 | stream->Write(variables); |
162 | // Save all constant data. |
163 | uint64_t sz = static_cast<uint64_t>(const_var_ndarray.size()); |
164 | stream->Write(sz); |
165 | for (uint64_t i = 0; i < sz; i++) { |
166 | const_var_ndarray[i].Save(stream); |
167 | } |
168 | |
169 | // Save the symbol to list of required constant variables mapping |
170 | std::vector<std::string> symbols; |
171 | std::vector<std::vector<std::string>> const_vars; |
172 | for (const auto& it : const_vars_by_symbol_) { |
173 | symbols.push_back(it.first); |
174 | const_vars.push_back(it.second); |
175 | } |
176 | |
177 | stream->Write(symbols); |
178 | sz = static_cast<uint64_t>(const_vars_by_symbol_.size()); |
179 | stream->Write(sz); |
180 | for (uint64_t i = 0; i < sz; i++) { |
181 | stream->Write(const_vars[i]); |
182 | } |
183 | } |
184 | |
185 | static Module LoadFromBinary(void* strm) { |
186 | dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm); |
187 | |
188 | // Load the variables. |
189 | std::vector<std::string> variables; |
190 | ICHECK(stream->Read(&variables)) << "Loading variable names failed" ; |
191 | uint64_t sz; |
192 | ICHECK(stream->Read(&sz, sizeof(sz))) << "Loading number of vars failed" ; |
193 | ICHECK_EQ(static_cast<size_t>(sz), variables.size()) |
194 | << "The number of variables and ndarray counts must match" ; |
195 | // Load the list of ndarray. |
196 | std::vector<NDArray> arrays; |
197 | for (uint64_t i = 0; i < sz; i++) { |
198 | NDArray temp; |
199 | temp.Load(stream); |
200 | arrays.push_back(temp); |
201 | } |
202 | |
203 | std::unordered_map<std::string, NDArray> const_var_ndarray; |
204 | for (uint64_t i = 0; i < sz; i++) { |
205 | ICHECK_EQ(const_var_ndarray.count(variables[i]), 0U); |
206 | const_var_ndarray[variables[i]] = arrays[i]; |
207 | } |
208 | |
209 | // Load the symbol to list of required constant variables mapping |
210 | std::vector<std::string> symbols; |
211 | ICHECK(stream->Read(&symbols)) << "Loading symbols failed" ; |
212 | ICHECK(stream->Read(&sz, sizeof(sz))) << "Loading number of symbols failed" ; |
213 | ICHECK_EQ(static_cast<size_t>(sz), symbols.size()); |
214 | std::vector<std::vector<std::string>> const_vars; |
215 | for (uint64_t i = 0; i < sz; i++) { |
216 | std::vector<std::string> vars; |
217 | ICHECK(stream->Read(&vars)) << "Loading const variables failed" ; |
218 | const_vars.push_back(vars); |
219 | } |
220 | |
221 | std::unordered_map<std::string, std::vector<std::string>> const_vars_by_symbol; |
222 | for (uint64_t i = 0; i < sz; i++) { |
223 | const_vars_by_symbol[symbols[i]] = const_vars[i]; |
224 | } |
225 | |
226 | auto n = make_object<ConstLoaderModuleNode>(const_var_ndarray, const_vars_by_symbol); |
227 | return Module(n); |
228 | } |
229 | |
230 | private: |
231 | /*! |
232 | * \brief Record if a module is initialized. It is needed by imported |
233 | * modules using execution engine. |
234 | */ |
235 | std::unordered_map<std::string, bool> initialized_; |
236 | /*! \brief Variable name to NDArray mapping. */ |
237 | std::unordered_map<std::string, NDArray> const_var_ndarray_; |
238 | /*! \brief Symbol name to required constant variables mapping. */ |
239 | std::unordered_map<std::string, std::vector<std::string>> const_vars_by_symbol_; |
240 | }; |
241 | |
242 | Module ConstLoaderModuleCreate( |
243 | const std::unordered_map<std::string, NDArray>& const_var_ndarray, |
244 | const std::unordered_map<std::string, std::vector<std::string>>& const_vars_by_symbol) { |
245 | auto n = make_object<ConstLoaderModuleNode>(const_var_ndarray, const_vars_by_symbol); |
246 | return Module(n); |
247 | } |
248 | |
// Register the binary deserializer under two global names; both are bound to
// the same ConstLoaderModuleNode::LoadFromBinary function.
// NOTE(review): "loadbinary_const_loader" matches this module's type_key()
// ("const_loader"). "loadbinary_metadata" is presumably kept so artifacts
// saved under an older "metadata" type key still load; confirm before removing.
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metadata")
    .set_body_typed(ConstLoaderModuleNode::LoadFromBinary);
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_const_loader")
    .set_body_typed(ConstLoaderModuleNode::LoadFromBinary);
253 | |
254 | } // namespace runtime |
255 | } // namespace tvm |
256 | |