1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file metadata_module.cc |
22 | * \brief Defines functions that build MetadataModules for C++ and C runtimes. |
23 | */ |
24 | #include "metadata_module.h" |
25 | |
26 | #include <tvm/relay/runtime.h> |
27 | |
28 | #include <utility> |
29 | #include <vector> |
30 | |
31 | #include "../runtime/const_loader_module.h" |
32 | #include "../runtime/meta_data.h" |
33 | #include "llvm/llvm_module.h" |
34 | #include "source/source_module.h" |
35 | |
36 | namespace tvm { |
37 | namespace codegen { |
38 | |
39 | static runtime::metadata::Metadata ConvertMetaData( |
40 | relay::backend::ExecutorCodegenMetadata metadata); |
41 | |
42 | static runtime::Module CreateCrtMetadataModule( |
43 | runtime::Module target_module, Target target, relay::Runtime runtime, relay::Executor executor, |
44 | relay::backend::ExecutorCodegenMetadata metadata, |
45 | Array<runtime::Module> non_crt_exportable_modules, |
46 | Array<runtime::Module> crt_exportable_modules, |
47 | const std::unordered_map<std::string, runtime::NDArray>& const_var_ndarray) { |
48 | if (!non_crt_exportable_modules.empty()) { |
49 | std::string non_exportable_modules; |
50 | for (unsigned int i = 0; i < non_crt_exportable_modules.size(); i++) { |
51 | if (i > 0) { |
52 | non_exportable_modules += ", " ; |
53 | } |
54 | auto mod = non_crt_exportable_modules[i]; |
55 | auto pf_sym = mod.GetFunction("get_symbol" ); |
56 | if (pf_sym != nullptr) { |
57 | non_exportable_modules += pf_sym().operator std::string(); |
58 | } else { |
59 | non_exportable_modules += |
60 | std::string{"(module type_key=" } + mod->type_key() + std::string{")" }; |
61 | } |
62 | } |
63 | CHECK(false) << "These " << non_crt_exportable_modules.size() |
64 | << " modules are not exportable to C-runtime: " << non_exportable_modules; |
65 | } |
66 | |
67 | if (target->kind->name == "c" ) { |
68 | runtime::metadata::Metadata aot_metadata; |
69 | if (executor->GetAttr<String>("interface-api" , tvm::String("packed" )) == "packed" ) { |
70 | aot_metadata = ConvertMetaData(metadata); |
71 | } |
72 | |
73 | crt_exportable_modules.push_back(target_module); |
74 | target_module = CreateCSourceCrtMetadataModule(crt_exportable_modules, target, runtime, |
75 | metadata, aot_metadata); |
76 | } else if (target->kind->name == "llvm" ) { |
77 | #ifdef TVM_LLVM_VERSION |
78 | crt_exportable_modules.push_back(target_module); |
79 | target_module = CreateLLVMCrtMetadataModule(crt_exportable_modules, target, runtime); |
80 | #else // TVM_LLVM_VERSION |
81 | LOG(FATAL) << "TVM was not built with LLVM enabled." ; |
82 | #endif // TVM_LLVM_VERSION |
83 | } |
84 | |
85 | return target_module; |
86 | } |
87 | |
88 | // TODO(areusch,masahi): Unify metadata representation and remove the need for this function |
89 | static runtime::metadata::Metadata ConvertMetaData( |
90 | relay::backend::ExecutorCodegenMetadata metadata) { |
91 | ICHECK(metadata.defined()); |
92 | ICHECK_NOTNULL(metadata->pool_inputs); |
93 | |
94 | std::vector<runtime::metadata::TensorInfo> inputs; |
95 | for (size_t i = 0; i < metadata->inputs.size(); ++i) { |
96 | auto v = metadata->inputs[i]; |
97 | auto ttype = metadata->input_tensor_types[i]; |
98 | inputs.push_back( |
99 | runtime::metadata::TensorInfo(make_object<target::metadata::InMemoryTensorInfoNode>( |
100 | v->name_hint, relay::backend::ShapeToJSON(ttype->shape), ttype->dtype))); |
101 | } |
102 | |
103 | std::vector<runtime::metadata::TensorInfo> outputs; |
104 | auto output_ttypes = metadata->output_tensor_types; |
105 | for (size_t i = 0; i < output_ttypes.size(); ++i) { |
106 | auto ttype = output_ttypes[i]; |
107 | std::stringstream name; |
108 | name << "output" << i; |
109 | outputs.push_back( |
110 | runtime::metadata::TensorInfo(make_object<target::metadata::InMemoryTensorInfoNode>( |
111 | name.str(), relay::backend::ShapeToJSON(ttype->shape), ttype->dtype))); |
112 | } |
113 | |
114 | std::vector<runtime::metadata::TensorInfo> pools; |
115 | for (size_t i = 0; i < metadata->pools.size(); ++i) { |
116 | auto var = metadata->pools[i]; |
117 | auto api = metadata->pool_inputs.value()[var]; |
118 | if (api->pool_info.as<WorkspacePoolInfoNode>()) { |
119 | pools.push_back( |
120 | runtime::metadata::TensorInfo(make_object<target::metadata::InMemoryTensorInfoNode>( |
121 | var->name_hint, std::vector<int64_t>{api->allocated_size.IntValue()}, |
122 | tvm::runtime::DataType{kDLUInt, 8, 1}))); |
123 | } |
124 | } |
125 | |
126 | std::vector<ConstantInfo> consts; |
127 | for (const auto& kv : metadata->pool_inputs.value()) { |
128 | const auto& api = kv.second; |
129 | if (const auto* pi = api->pool_info.as<ConstantPoolInfoNode>()) { |
130 | if (pi->is_internal) { |
131 | for (const auto ci : pi->constant_info_array) { |
132 | consts.emplace_back(ci->name_hint, ci->byte_offset, ci->data); |
133 | } |
134 | } |
135 | } |
136 | } |
137 | auto n = make_object<target::metadata::InMemoryMetadataNode>( |
138 | runtime::metadata::kMetadataVersion, inputs, outputs, pools, consts, metadata->mod_name); |
139 | |
140 | return runtime::metadata::Metadata(std::move(n)); |
141 | } |
142 | |
143 | static runtime::Module CreateCppMetadataModule( |
144 | runtime::Module target_module, Target target, relay::Runtime runtime, |
145 | relay::backend::ExecutorCodegenMetadata metadata, |
146 | const std::unordered_map<std::string, std::vector<std::string>>& const_vars_by_symbol, |
147 | Array<runtime::Module> non_crt_exportable_modules, |
148 | Array<runtime::Module> crt_exportable_modules, |
149 | const std::unordered_map<std::string, runtime::NDArray>& const_var_ndarray) { |
150 | if (!non_crt_exportable_modules.empty()) { |
151 | runtime::Module const_loader_mod = |
152 | runtime::ConstLoaderModuleCreate(const_var_ndarray, const_vars_by_symbol); |
153 | const_loader_mod.Import(target_module); |
154 | for (const auto& it : non_crt_exportable_modules) { |
155 | const_loader_mod.Import(it); |
156 | } |
157 | target_module = const_loader_mod; |
158 | } |
159 | |
160 | if (metadata.defined()) { |
161 | runtime::metadata::Metadata runtime_metadata = ConvertMetaData(metadata); |
162 | |
163 | if (metadata->executor == runtime::kTvmExecutorAot && runtime->name == relay::kTvmRuntimeCpp) { |
164 | if (target->kind->name == "c" ) { |
165 | auto metadata_module = CreateCSourceCppMetadataModule(runtime_metadata); |
166 | metadata_module->Import(target_module); |
167 | target_module = metadata_module; |
168 | #ifdef TVM_LLVM_VERSION // defining TVM_LLVM_VERSION indicates TVM was compiled with USE_LLVM ON. |
169 | } else if (target->kind->name == "llvm" ) { |
170 | auto metadata_module = CreateLLVMCppMetadataModule(runtime_metadata, target, runtime); |
171 | metadata_module->Import(target_module); |
172 | target_module = metadata_module; |
173 | #endif // TVM_LLVM_VERSION |
174 | } else { |
175 | CHECK(false) << "Don't know how to create MetadataModule for target type " << target->str(); |
176 | } |
177 | } |
178 | } |
179 | |
180 | return target_module; |
181 | } |
182 | |
183 | /*! |
184 | * \brief Create a metadata module wrapper. The helper is used by different |
185 | * codegens, such as graph executor codegen and the vm compiler. |
186 | * |
187 | * \param params The metadata for initialization of all modules. |
188 | * \param target_module the internal module that is compiled by tvm. |
189 | * \param ext_modules The external modules that needs to be imported inside the metadata |
190 | * module(s). |
191 | * \param target The target that all the modules are compiled for |
192 | * \return The created metadata module that manages initialization of metadata. |
193 | */ |
194 | runtime::Module CreateMetadataModule( |
195 | const std::unordered_map<std::string, runtime::NDArray>& const_var_ndarray, |
196 | tvm::runtime::Module target_module, const Array<runtime::Module>& ext_modules, Target target, |
197 | tvm::relay::Runtime runtime, tvm::relay::Executor executor, |
198 | relay::backend::ExecutorCodegenMetadata metadata) { |
199 | // Here we split modules into two groups: |
200 | // 1. Those modules which can be exported to C-runtime. These are DSO-exportable |
201 | // (i.e. llvm or c) modules which return nothing from get_const_vars(). |
202 | // 2. Other modules. |
203 | Array<runtime::Module> crt_exportable_modules; |
204 | Array<runtime::Module> non_crt_exportable_modules; |
205 | |
206 | bool is_targeting_crt = runtime->name == "crt" ; |
207 | |
208 | // Wrap all submodules in the initialization wrapper. |
209 | std::unordered_map<std::string, std::vector<std::string>> const_vars_by_symbol; |
210 | for (tvm::runtime::Module mod : ext_modules) { |
211 | auto pf_sym = mod.GetFunction("get_symbol" ); |
212 | auto pf_var = mod.GetFunction("get_const_vars" ); |
213 | std::vector<std::string> symbol_const_vars; |
214 | if (pf_sym != nullptr && pf_var != nullptr) { |
215 | String symbol = pf_sym(); |
216 | Array<String> variables = pf_var(); |
217 | for (size_t i = 0; i < variables.size(); i++) { |
218 | symbol_const_vars.push_back(variables[i].operator std::string()); |
219 | } |
220 | ICHECK_EQ(const_vars_by_symbol.count(symbol), 0U) << "Found duplicated symbol: " << symbol; |
221 | const_vars_by_symbol[symbol] = symbol_const_vars; |
222 | } |
223 | // We only need loading of serialized constant data |
224 | // if there are constants present and required by the |
225 | // runtime module to be initialized by the binary |
226 | // metadata module. If not rest of the modules are |
227 | // wrapped in c-source metadata module. |
228 | |
229 | // TODO(@manupa-arm) : we should be able to use csource_metadata |
230 | // if the variables are empty when all the runtime modules implement get_func_names |
231 | if (symbol_const_vars.empty() && is_targeting_crt && mod->IsDSOExportable() && |
232 | (target->kind->name == "c" || target->kind->name == "llvm" )) { |
233 | crt_exportable_modules.push_back(mod); |
234 | } else { |
235 | non_crt_exportable_modules.push_back(mod); |
236 | } |
237 | } |
238 | |
239 | if (is_targeting_crt) { |
240 | return CreateCrtMetadataModule(target_module, target, runtime, executor, metadata, |
241 | non_crt_exportable_modules, crt_exportable_modules, |
242 | const_var_ndarray); |
243 | } else { |
244 | return CreateCppMetadataModule(target_module, target, runtime, metadata, const_vars_by_symbol, |
245 | non_crt_exportable_modules, crt_exportable_modules, |
246 | const_var_ndarray); |
247 | } |
248 | } |
249 | |
250 | } // namespace codegen |
251 | |
252 | } // namespace tvm |
253 | |