1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file src/runtime/const_loader_module.cc
 * \brief A wrapper for initializing imported modules using constant NDArray. This
 * module is intended to be used by various runtimes in the TVM stack, e.g.
 * graph executor, relay VM, AOT runtime, and various user defined runtimes. It
 * paves the way to separate the code and metadata, which makes compilation
 * and/or interpretation more convenient. In addition, the clear separation of
 * code and constants significantly reduces the effort for handling external
 * codegen and runtimes.
 */
30#include <tvm/runtime/container/array.h>
31#include <tvm/runtime/container/string.h>
32#include <tvm/runtime/ndarray.h>
33#include <tvm/runtime/packed_func.h>
34#include <tvm/runtime/registry.h>
35
36#include <cstdint>
37#include <sstream>
38
39#include "meta_data.h"
40
41namespace tvm {
42namespace runtime {
43
44/*!
45 * \brief The const-loader module is designed to manage initialization of the
46 * imported submodules for the C++ runtime.
47 */
48class ConstLoaderModuleNode : public ModuleNode {
49 public:
50 ConstLoaderModuleNode(
51 const std::unordered_map<std::string, NDArray>& const_var_ndarray,
52 const std::unordered_map<std::string, std::vector<std::string>>& const_vars_by_symbol)
53 : const_var_ndarray_(const_var_ndarray), const_vars_by_symbol_(const_vars_by_symbol) {
54 VLOG(1) << "Creating ConstLoaderModule";
55 // Only the related submodules are cached to reduce the number of runtime
56 // symbol lookup for initialization. Otherwise, symbols/primitives in the
57 // DSO module will also be cached but they never need to be initialized.
58 for (const auto& kv : const_vars_by_symbol_) {
59 for (const auto& var : kv.second) {
60 VLOG(1) << "ConstLoaderModuleNode has constant '" << var << "' for function '" << kv.first
61 << "'";
62 ICHECK_GT(const_var_ndarray_.count(var), 0)
63 << "ConstLoaderModuleNode is missing entry for constant '" << var << "' for function '"
64 << kv.first << "'";
65 }
66 initialized_[kv.first] = false;
67 }
68 }
69
70 PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
71 VLOG(1) << "ConstLoaderModuleNode::GetFunction(" << name << ")";
72 // Initialize and memoize the module.
73 // Usually, we have some warmup runs. The module initialization should be
74 // done at this stage. Therefore, runtime overhead is not a concern.
75 if (initialized_.count(name) && !initialized_.at(name)) {
76 this->InitSubModule(name);
77 initialized_[name] = true;
78 }
79
80 if (name == "get_const_var_ndarray") {
81 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
82 Map<String, ObjectRef> ret_map;
83 for (const auto& kv : const_var_ndarray_) {
84 ret_map.Set(kv.first, kv.second);
85 }
86 *rv = ret_map;
87 });
88 }
89
90 // Run the module.
91 // Normally we would only have a limited number of submodules. The runtime
92 // symobl lookup overhead should be minimal.
93 ICHECK(!this->imports().empty());
94 for (Module it : this->imports()) {
95 PackedFunc pf = it.GetFunction(name);
96 if (pf != nullptr) return pf;
97 }
98 return PackedFunc(nullptr);
99 }
100
101 const char* type_key() const final { return "const_loader"; }
102
103 /*!
104 * \brief Get the list of constants that is required by the given module.
105 * \param symbol The symbol that is being queried.
106 * \return The list of needed NDArray.
107 */
108 Array<NDArray> GetRequiredConstants(const std::string& symbol) {
109 Array<NDArray> ret;
110 ICHECK_GT(const_vars_by_symbol_.count(symbol), 0U)
111 << "No constants known for function '" << symbol << "'";
112 std::vector<std::string> vars = const_vars_by_symbol_[symbol];
113 for (const auto& var : vars) {
114 ICHECK_GT(const_var_ndarray_.count(var), 0U)
115 << "No such constant variable '" << var << "' for function '" << symbol << "'";
116 ret.push_back(const_var_ndarray_[var]);
117 }
118 return ret;
119 }
120
121 /*!
122 * \brief Initialize each imported module.
123 * \param symobl The symbol used for initializing a module. It is also used
124 * for runtime lookup.
125 *
126 * \note A module could be like the following:
127 * ConstLoaderModuleNode (contains all the constants)
128 * - CSourceModule
129 * - JSON runtime module
130 *
131 * The initializer iterates through the imported modules and intilizes the
132 * found module accordingly by passing the needed constants into it.
133 */
134 void InitSubModule(const std::string& symbol) {
135 PackedFunc init(nullptr);
136 for (Module it : this->imports()) {
137 // Get the initialization function from the imported modules.
138 std::string init_name = "__init_" + symbol;
139 init = it.GetFunction(init_name, false);
140 if (init != nullptr) {
141 auto md = GetRequiredConstants(symbol);
142 // Initialize the module with constants.
143 int ret = init(md);
144 // Report the error if initialization is failed.
145 ICHECK_EQ(ret, 0) << TVMGetLastError();
146 break;
147 }
148 }
149 }
150
151 void SaveToBinary(dmlc::Stream* stream) final {
152 std::vector<std::string> variables;
153 std::vector<NDArray> const_var_ndarray;
154 for (const auto& it : const_var_ndarray_) {
155 String var_name = it.first;
156 variables.push_back(var_name);
157 const_var_ndarray.push_back(it.second);
158 }
159
160 // Save all variables in the function.
161 stream->Write(variables);
162 // Save all constant data.
163 uint64_t sz = static_cast<uint64_t>(const_var_ndarray.size());
164 stream->Write(sz);
165 for (uint64_t i = 0; i < sz; i++) {
166 const_var_ndarray[i].Save(stream);
167 }
168
169 // Save the symbol to list of required constant variables mapping
170 std::vector<std::string> symbols;
171 std::vector<std::vector<std::string>> const_vars;
172 for (const auto& it : const_vars_by_symbol_) {
173 symbols.push_back(it.first);
174 const_vars.push_back(it.second);
175 }
176
177 stream->Write(symbols);
178 sz = static_cast<uint64_t>(const_vars_by_symbol_.size());
179 stream->Write(sz);
180 for (uint64_t i = 0; i < sz; i++) {
181 stream->Write(const_vars[i]);
182 }
183 }
184
185 static Module LoadFromBinary(void* strm) {
186 dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
187
188 // Load the variables.
189 std::vector<std::string> variables;
190 ICHECK(stream->Read(&variables)) << "Loading variable names failed";
191 uint64_t sz;
192 ICHECK(stream->Read(&sz, sizeof(sz))) << "Loading number of vars failed";
193 ICHECK_EQ(static_cast<size_t>(sz), variables.size())
194 << "The number of variables and ndarray counts must match";
195 // Load the list of ndarray.
196 std::vector<NDArray> arrays;
197 for (uint64_t i = 0; i < sz; i++) {
198 NDArray temp;
199 temp.Load(stream);
200 arrays.push_back(temp);
201 }
202
203 std::unordered_map<std::string, NDArray> const_var_ndarray;
204 for (uint64_t i = 0; i < sz; i++) {
205 ICHECK_EQ(const_var_ndarray.count(variables[i]), 0U);
206 const_var_ndarray[variables[i]] = arrays[i];
207 }
208
209 // Load the symbol to list of required constant variables mapping
210 std::vector<std::string> symbols;
211 ICHECK(stream->Read(&symbols)) << "Loading symbols failed";
212 ICHECK(stream->Read(&sz, sizeof(sz))) << "Loading number of symbols failed";
213 ICHECK_EQ(static_cast<size_t>(sz), symbols.size());
214 std::vector<std::vector<std::string>> const_vars;
215 for (uint64_t i = 0; i < sz; i++) {
216 std::vector<std::string> vars;
217 ICHECK(stream->Read(&vars)) << "Loading const variables failed";
218 const_vars.push_back(vars);
219 }
220
221 std::unordered_map<std::string, std::vector<std::string>> const_vars_by_symbol;
222 for (uint64_t i = 0; i < sz; i++) {
223 const_vars_by_symbol[symbols[i]] = const_vars[i];
224 }
225
226 auto n = make_object<ConstLoaderModuleNode>(const_var_ndarray, const_vars_by_symbol);
227 return Module(n);
228 }
229
230 private:
231 /*!
232 * \brief Record if a module is initialized. It is needed by imported
233 * modules using execution engine.
234 */
235 std::unordered_map<std::string, bool> initialized_;
236 /*! \brief Variable name to NDArray mapping. */
237 std::unordered_map<std::string, NDArray> const_var_ndarray_;
238 /*! \brief Symbol name to required constant variables mapping. */
239 std::unordered_map<std::string, std::vector<std::string>> const_vars_by_symbol_;
240};
241
242Module ConstLoaderModuleCreate(
243 const std::unordered_map<std::string, NDArray>& const_var_ndarray,
244 const std::unordered_map<std::string, std::vector<std::string>>& const_vars_by_symbol) {
245 auto n = make_object<ConstLoaderModuleNode>(const_var_ndarray, const_vars_by_symbol);
246 return Module(n);
247}
248
249TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metadata")
250 .set_body_typed(ConstLoaderModuleNode::LoadFromBinary);
251TVM_REGISTER_GLOBAL("runtime.module.loadbinary_const_loader")
252 .set_body_typed(ConstLoaderModuleNode::LoadFromBinary);
253
254} // namespace runtime
255} // namespace tvm
256