1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen.cc |
 * \brief Common utilities to generate C style code.
23 | */ |
24 | #include <dmlc/memory_io.h> |
25 | #include <tvm/ir/module.h> |
26 | #include <tvm/runtime/c_runtime_api.h> |
27 | #include <tvm/runtime/module.h> |
28 | #include <tvm/runtime/registry.h> |
29 | #include <tvm/target/codegen.h> |
30 | #include <tvm/target/target.h> |
31 | #include <tvm/tir/function.h> |
32 | #include <tvm/tir/transform.h> |
33 | |
34 | #include <cstdint> |
35 | #include <cstring> |
36 | #include <sstream> |
37 | #include <unordered_set> |
38 | #include <vector> |
39 | |
40 | namespace tvm { |
41 | namespace codegen { |
42 | |
43 | runtime::Module Build(IRModule mod, Target target) { |
44 | if (transform::PassContext::Current() |
45 | ->GetConfig<Bool>("tir.disable_assert" , Bool(false)) |
46 | .value()) { |
47 | mod = tir::transform::SkipAssert()(mod); |
48 | } |
49 | |
50 | auto target_attr_map = tvm::TargetKind::GetAttrMap<FTVMTIRToRuntime>("TIRToRuntime" ); |
51 | if (target_attr_map.count(target->kind)) { |
52 | return target_attr_map[target->kind](mod, target); |
53 | } |
54 | |
55 | // the build function. |
56 | std::string build_f_name = "target.build." + target->kind->name; |
57 | const PackedFunc* bf = runtime::Registry::Get(build_f_name); |
58 | ICHECK(bf != nullptr) << build_f_name << " is not enabled" ; |
59 | return (*bf)(mod, target); |
60 | } |
61 | |
62 | /*! \brief Helper class to serialize module */ |
63 | class ModuleSerializer { |
64 | public: |
65 | explicit ModuleSerializer(runtime::Module mod) : mod_(mod) { Init(); } |
66 | |
67 | void SerializeModule(dmlc::Stream* stream) { |
68 | // Only have one DSO module and it is in the root, then |
69 | // we will not produce import_tree_. |
70 | bool has_import_tree = true; |
71 | if (mod_->IsDSOExportable() && mod_->imports().empty()) { |
72 | has_import_tree = false; |
73 | } |
74 | uint64_t sz = 0; |
75 | if (has_import_tree) { |
76 | // we will append one key for _import_tree |
77 | // The layout is the same as before: binary_size, key, logic, key, logic... |
78 | sz = mod_group_vec_.size() + 1; |
79 | } else { |
80 | // Keep the old behaviour |
81 | sz = mod_->imports().size(); |
82 | } |
83 | stream->Write(sz); |
84 | |
85 | for (const auto& group : mod_group_vec_) { |
86 | ICHECK_NE(group.size(), 0) << "Every allocated group must have at least one module" ; |
87 | if (!group[0]->IsDSOExportable()) { |
88 | ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged" ; |
89 | std::string mod_type_key = group[0]->type_key(); |
90 | stream->Write(mod_type_key); |
91 | group[0]->SaveToBinary(stream); |
92 | } else { |
93 | // DSOExportable: do not need binary |
94 | if (has_import_tree) { |
95 | std::string mod_type_key = "_lib" ; |
96 | stream->Write(mod_type_key); |
97 | } |
98 | } |
99 | } |
100 | |
101 | // Write _import_tree key if we have |
102 | if (has_import_tree) { |
103 | std::string import_key = "_import_tree" ; |
104 | stream->Write(import_key); |
105 | stream->Write(import_tree_row_ptr_); |
106 | stream->Write(import_tree_child_indices_); |
107 | } |
108 | } |
109 | |
110 | private: |
111 | void Init() { |
112 | CreateModuleIndex(); |
113 | CreateImportTree(); |
114 | } |
115 | |
116 | // invariance: root module is always at location 0. |
117 | // The module order is collected via DFS |
118 | // This function merges all the DSO exportable module into |
119 | // a single one as this is also what happens in the final hierachy |
120 | void CreateModuleIndex() { |
121 | std::unordered_set<const runtime::ModuleNode*> visited{mod_.operator->()}; |
122 | std::vector<runtime::ModuleNode*> stack{mod_.operator->()}; |
123 | uint64_t module_index = 0; |
124 | |
125 | auto fpush_imports_to_stack = [&](runtime::ModuleNode* node) { |
126 | for (runtime::Module m : node->imports()) { |
127 | runtime::ModuleNode* next = m.operator->(); |
128 | if (visited.count(next) == 0) { |
129 | visited.insert(next); |
130 | stack.push_back(next); |
131 | } |
132 | } |
133 | }; |
134 | |
135 | std::vector<runtime::ModuleNode*> dso_exportable_boundary; |
136 | |
137 | // Create module index that merges all dso module into a single group. |
138 | // |
139 | // Do a two phase visit, to ensure dso module's index |
140 | // is always bigger than a parent of any dso module |
141 | // and smaller than children of any dso module. |
142 | // |
143 | // Error will be raised in CreateImportTree |
144 | // if merging dso module causes a cycle in the import tree |
145 | |
146 | // Phase 0: only expand non-dso-module and record the boundary. |
147 | while (!stack.empty()) { |
148 | runtime::ModuleNode* n = stack.back(); |
149 | stack.pop_back(); |
150 | if (n->IsDSOExportable()) { |
151 | // do not recursively expand dso modules |
152 | // we will expand in phase 1 |
153 | dso_exportable_boundary.emplace_back(n); |
154 | } else { |
155 | // expand the non-dso modules |
156 | mod2index_[n] = module_index++; |
157 | mod_group_vec_.emplace_back(std::vector<runtime::ModuleNode*>({n})); |
158 | fpush_imports_to_stack(n); |
159 | } |
160 | } |
161 | if (dso_exportable_boundary.size() == 0) return; |
162 | |
163 | // create the slot for dso exportable modules |
164 | // This index is chosen so that all the DSO's parents are |
165 | // allocated before this index, and children will be allocated after |
166 | uint64_t dso_module_index = module_index++; |
167 | mod_group_vec_.emplace_back(std::vector<runtime::ModuleNode*>()); |
168 | |
169 | // restart visiting the stack using elements in dso exportable boundary |
170 | stack = std::move(dso_exportable_boundary); |
171 | |
172 | // Phase 1: expand the children of dso modules. |
173 | while (!stack.empty()) { |
174 | runtime::ModuleNode* n = stack.back(); |
175 | stack.pop_back(); |
176 | |
177 | if (n->IsDSOExportable()) { |
178 | mod_group_vec_[dso_module_index].emplace_back(n); |
179 | mod2index_[n] = dso_module_index; |
180 | } else { |
181 | mod2index_[n] = module_index++; |
182 | mod_group_vec_.emplace_back(std::vector<runtime::ModuleNode*>({n})); |
183 | } |
184 | fpush_imports_to_stack(n); |
185 | } |
186 | } |
187 | |
188 | void CreateImportTree() { |
189 | std::vector<int64_t> child_indices; |
190 | |
191 | for (size_t parent_index = 0; parent_index < mod_group_vec_.size(); ++parent_index) { |
192 | child_indices.clear(); |
193 | for (const auto* m : mod_group_vec_[parent_index]) { |
194 | for (runtime::Module im : m->imports()) { |
195 | uint64_t mod_index = mod2index_.at(im.operator->()); |
196 | // skip cycle when dso modules are merged together |
197 | if (mod_index != parent_index) { |
198 | child_indices.emplace_back(mod_index); |
199 | } |
200 | } |
201 | } |
202 | // sort and unique the merged indices |
203 | std::sort(child_indices.begin(), child_indices.end()); |
204 | auto unique_end = std::unique(child_indices.begin(), child_indices.end()); |
205 | |
206 | // Check cycles due to merging dso exportable modules. |
207 | if (child_indices.size() != 0) { |
208 | // The index is supposed to follow the topological order. |
209 | CHECK_LT(parent_index, child_indices[0]) |
210 | << "RuntimeError: Cannot export due to multiple dso-exportables " |
211 | << "that cannot be merged without creating a cycle in the import tree. " |
212 | << "Related module keys: parent=" << mod_group_vec_[parent_index][0]->type_key() |
213 | << ", child=" << mod_group_vec_[child_indices[0]][0]->type_key(); |
214 | } |
215 | // insert the child indices |
216 | import_tree_child_indices_.insert(import_tree_child_indices_.end(), child_indices.begin(), |
217 | unique_end); |
218 | import_tree_row_ptr_.push_back(import_tree_child_indices_.size()); |
219 | } |
220 | } |
221 | |
222 | runtime::Module mod_; |
223 | // construct module to index |
224 | std::unordered_map<runtime::ModuleNode*, size_t> mod2index_; |
225 | // index -> module group |
226 | std::vector<std::vector<runtime::ModuleNode*>> mod_group_vec_; |
227 | std::vector<uint64_t> import_tree_row_ptr_{0}; |
228 | std::vector<uint64_t> import_tree_child_indices_; |
229 | }; |
230 | |
231 | namespace { |
232 | std::string SerializeModule(const runtime::Module& mod) { |
233 | std::string bin; |
234 | dmlc::MemoryStringStream ms(&bin); |
235 | dmlc::Stream* stream = &ms; |
236 | |
237 | ModuleSerializer module_serializer(mod); |
238 | module_serializer.SerializeModule(stream); |
239 | |
240 | return bin; |
241 | } |
242 | } // namespace |
243 | |
244 | std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { |
245 | std::string bin = SerializeModule(mod); |
246 | |
247 | // translate to C program |
248 | std::ostringstream os; |
249 | os << "#ifdef _WIN32\n" |
250 | << "#define TVM_EXPORT __declspec(dllexport)\n" |
251 | << "#else\n" |
252 | << "#define TVM_EXPORT\n" |
253 | << "#endif\n" ; |
254 | os << "#ifdef __cplusplus\n" |
255 | << "extern \"C\" {\n" |
256 | << "#endif\n" ; |
257 | os << "TVM_EXPORT extern const unsigned char " << runtime::symbol::tvm_dev_mblob << "[];\n" ; |
258 | uint64_t nbytes = bin.length(); |
259 | os << "const unsigned char " << runtime::symbol::tvm_dev_mblob << "[" |
260 | << bin.length() + sizeof(nbytes) << "] = {\n " ; |
261 | os << std::hex; |
262 | size_t nunit = 80 / 4; |
263 | for (size_t i = 0; i < sizeof(nbytes); ++i) { |
264 | // sperators |
265 | if (i != 0) { |
266 | os << "," ; |
267 | } |
268 | os << "0x" << ((nbytes >> (i * 8)) & 0xffUL); |
269 | } |
270 | for (size_t i = 0; i < bin.length(); ++i) { |
271 | // sperators |
272 | if ((i + sizeof(nbytes)) % nunit == 0) { |
273 | os << ",\n " ; |
274 | } else { |
275 | os << "," ; |
276 | } |
277 | int c = bin[i]; |
278 | os << "0x" << (c & 0xff); |
279 | } |
280 | os << "\n};\n" ; |
281 | if (system_lib) { |
282 | os << "extern int TVMBackendRegisterSystemLibSymbol(const char*, void*);\n" ; |
283 | os << "static int " << runtime::symbol::tvm_dev_mblob << "_reg_ = " |
284 | << "TVMBackendRegisterSystemLibSymbol(\"" << runtime::symbol::tvm_dev_mblob << "\", (void*)" |
285 | << runtime::symbol::tvm_dev_mblob << ");\n" ; |
286 | } |
287 | os << "#ifdef __cplusplus\n" |
288 | << "}\n" |
289 | << "#endif\n" ; |
290 | return os.str(); |
291 | } |
292 | |
293 | runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib, |
294 | const std::string& llvm_target_string) { |
295 | std::string bin = SerializeModule(mod); |
296 | |
297 | uint64_t nbytes = bin.length(); |
298 | std::string ; |
299 | for (size_t i = 0; i < sizeof(nbytes); ++i) { |
300 | header.push_back(((nbytes >> (i * 8)) & 0xffUL)); |
301 | } |
302 | std::string blob = header + bin; |
303 | TVMByteArray blob_byte_array; |
304 | blob_byte_array.size = blob.length(); |
305 | blob_byte_array.data = blob.data(); |
306 | |
307 | // Call codegen_blob to generate LLVM module |
308 | std::string codegen_f_name = "codegen.codegen_blob" ; |
309 | // the codegen function. |
310 | const PackedFunc* codegen_f = runtime::Registry::Get(codegen_f_name); |
311 | ICHECK(codegen_f != nullptr) << "codegen.codegen_blob is not presented." ; |
312 | return (*codegen_f)(blob_byte_array, system_lib, llvm_target_string); |
313 | } |
314 | |
// Register Build as the global function "target.Build".
TVM_REGISTER_GLOBAL("target.Build" ).set_body_typed(Build);

// Export two auxiliary functions to the runtime namespace.
// "runtime.ModulePackImportsToC" is bound to PackImportsToC.
TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC" ).set_body_typed(PackImportsToC);

// "runtime.ModulePackImportsToLLVM" is bound to PackImportsToLLVM.
TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToLLVM" ).set_body_typed(PackImportsToLLVM);
321 | |
322 | } // namespace codegen |
323 | } // namespace tvm |
324 | |