1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file codegen.cc
 * \brief Common utilities for generating C style code.
 */
24#include <dmlc/memory_io.h>
25#include <tvm/ir/module.h>
26#include <tvm/runtime/c_runtime_api.h>
27#include <tvm/runtime/module.h>
28#include <tvm/runtime/registry.h>
29#include <tvm/target/codegen.h>
30#include <tvm/target/target.h>
31#include <tvm/tir/function.h>
32#include <tvm/tir/transform.h>
33
34#include <cstdint>
35#include <cstring>
36#include <sstream>
37#include <unordered_set>
38#include <vector>
39
40namespace tvm {
41namespace codegen {
42
43runtime::Module Build(IRModule mod, Target target) {
44 if (transform::PassContext::Current()
45 ->GetConfig<Bool>("tir.disable_assert", Bool(false))
46 .value()) {
47 mod = tir::transform::SkipAssert()(mod);
48 }
49
50 auto target_attr_map = tvm::TargetKind::GetAttrMap<FTVMTIRToRuntime>("TIRToRuntime");
51 if (target_attr_map.count(target->kind)) {
52 return target_attr_map[target->kind](mod, target);
53 }
54
55 // the build function.
56 std::string build_f_name = "target.build." + target->kind->name;
57 const PackedFunc* bf = runtime::Registry::Get(build_f_name);
58 ICHECK(bf != nullptr) << build_f_name << " is not enabled";
59 return (*bf)(mod, target);
60}
61
62/*! \brief Helper class to serialize module */
63class ModuleSerializer {
64 public:
65 explicit ModuleSerializer(runtime::Module mod) : mod_(mod) { Init(); }
66
67 void SerializeModule(dmlc::Stream* stream) {
68 // Only have one DSO module and it is in the root, then
69 // we will not produce import_tree_.
70 bool has_import_tree = true;
71 if (mod_->IsDSOExportable() && mod_->imports().empty()) {
72 has_import_tree = false;
73 }
74 uint64_t sz = 0;
75 if (has_import_tree) {
76 // we will append one key for _import_tree
77 // The layout is the same as before: binary_size, key, logic, key, logic...
78 sz = mod_group_vec_.size() + 1;
79 } else {
80 // Keep the old behaviour
81 sz = mod_->imports().size();
82 }
83 stream->Write(sz);
84
85 for (const auto& group : mod_group_vec_) {
86 ICHECK_NE(group.size(), 0) << "Every allocated group must have at least one module";
87 if (!group[0]->IsDSOExportable()) {
88 ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged";
89 std::string mod_type_key = group[0]->type_key();
90 stream->Write(mod_type_key);
91 group[0]->SaveToBinary(stream);
92 } else {
93 // DSOExportable: do not need binary
94 if (has_import_tree) {
95 std::string mod_type_key = "_lib";
96 stream->Write(mod_type_key);
97 }
98 }
99 }
100
101 // Write _import_tree key if we have
102 if (has_import_tree) {
103 std::string import_key = "_import_tree";
104 stream->Write(import_key);
105 stream->Write(import_tree_row_ptr_);
106 stream->Write(import_tree_child_indices_);
107 }
108 }
109
110 private:
111 void Init() {
112 CreateModuleIndex();
113 CreateImportTree();
114 }
115
116 // invariance: root module is always at location 0.
117 // The module order is collected via DFS
118 // This function merges all the DSO exportable module into
119 // a single one as this is also what happens in the final hierachy
120 void CreateModuleIndex() {
121 std::unordered_set<const runtime::ModuleNode*> visited{mod_.operator->()};
122 std::vector<runtime::ModuleNode*> stack{mod_.operator->()};
123 uint64_t module_index = 0;
124
125 auto fpush_imports_to_stack = [&](runtime::ModuleNode* node) {
126 for (runtime::Module m : node->imports()) {
127 runtime::ModuleNode* next = m.operator->();
128 if (visited.count(next) == 0) {
129 visited.insert(next);
130 stack.push_back(next);
131 }
132 }
133 };
134
135 std::vector<runtime::ModuleNode*> dso_exportable_boundary;
136
137 // Create module index that merges all dso module into a single group.
138 //
139 // Do a two phase visit, to ensure dso module's index
140 // is always bigger than a parent of any dso module
141 // and smaller than children of any dso module.
142 //
143 // Error will be raised in CreateImportTree
144 // if merging dso module causes a cycle in the import tree
145
146 // Phase 0: only expand non-dso-module and record the boundary.
147 while (!stack.empty()) {
148 runtime::ModuleNode* n = stack.back();
149 stack.pop_back();
150 if (n->IsDSOExportable()) {
151 // do not recursively expand dso modules
152 // we will expand in phase 1
153 dso_exportable_boundary.emplace_back(n);
154 } else {
155 // expand the non-dso modules
156 mod2index_[n] = module_index++;
157 mod_group_vec_.emplace_back(std::vector<runtime::ModuleNode*>({n}));
158 fpush_imports_to_stack(n);
159 }
160 }
161 if (dso_exportable_boundary.size() == 0) return;
162
163 // create the slot for dso exportable modules
164 // This index is chosen so that all the DSO's parents are
165 // allocated before this index, and children will be allocated after
166 uint64_t dso_module_index = module_index++;
167 mod_group_vec_.emplace_back(std::vector<runtime::ModuleNode*>());
168
169 // restart visiting the stack using elements in dso exportable boundary
170 stack = std::move(dso_exportable_boundary);
171
172 // Phase 1: expand the children of dso modules.
173 while (!stack.empty()) {
174 runtime::ModuleNode* n = stack.back();
175 stack.pop_back();
176
177 if (n->IsDSOExportable()) {
178 mod_group_vec_[dso_module_index].emplace_back(n);
179 mod2index_[n] = dso_module_index;
180 } else {
181 mod2index_[n] = module_index++;
182 mod_group_vec_.emplace_back(std::vector<runtime::ModuleNode*>({n}));
183 }
184 fpush_imports_to_stack(n);
185 }
186 }
187
188 void CreateImportTree() {
189 std::vector<int64_t> child_indices;
190
191 for (size_t parent_index = 0; parent_index < mod_group_vec_.size(); ++parent_index) {
192 child_indices.clear();
193 for (const auto* m : mod_group_vec_[parent_index]) {
194 for (runtime::Module im : m->imports()) {
195 uint64_t mod_index = mod2index_.at(im.operator->());
196 // skip cycle when dso modules are merged together
197 if (mod_index != parent_index) {
198 child_indices.emplace_back(mod_index);
199 }
200 }
201 }
202 // sort and unique the merged indices
203 std::sort(child_indices.begin(), child_indices.end());
204 auto unique_end = std::unique(child_indices.begin(), child_indices.end());
205
206 // Check cycles due to merging dso exportable modules.
207 if (child_indices.size() != 0) {
208 // The index is supposed to follow the topological order.
209 CHECK_LT(parent_index, child_indices[0])
210 << "RuntimeError: Cannot export due to multiple dso-exportables "
211 << "that cannot be merged without creating a cycle in the import tree. "
212 << "Related module keys: parent=" << mod_group_vec_[parent_index][0]->type_key()
213 << ", child=" << mod_group_vec_[child_indices[0]][0]->type_key();
214 }
215 // insert the child indices
216 import_tree_child_indices_.insert(import_tree_child_indices_.end(), child_indices.begin(),
217 unique_end);
218 import_tree_row_ptr_.push_back(import_tree_child_indices_.size());
219 }
220 }
221
222 runtime::Module mod_;
223 // construct module to index
224 std::unordered_map<runtime::ModuleNode*, size_t> mod2index_;
225 // index -> module group
226 std::vector<std::vector<runtime::ModuleNode*>> mod_group_vec_;
227 std::vector<uint64_t> import_tree_row_ptr_{0};
228 std::vector<uint64_t> import_tree_child_indices_;
229};
230
231namespace {
232std::string SerializeModule(const runtime::Module& mod) {
233 std::string bin;
234 dmlc::MemoryStringStream ms(&bin);
235 dmlc::Stream* stream = &ms;
236
237 ModuleSerializer module_serializer(mod);
238 module_serializer.SerializeModule(stream);
239
240 return bin;
241}
242} // namespace
243
244std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
245 std::string bin = SerializeModule(mod);
246
247 // translate to C program
248 std::ostringstream os;
249 os << "#ifdef _WIN32\n"
250 << "#define TVM_EXPORT __declspec(dllexport)\n"
251 << "#else\n"
252 << "#define TVM_EXPORT\n"
253 << "#endif\n";
254 os << "#ifdef __cplusplus\n"
255 << "extern \"C\" {\n"
256 << "#endif\n";
257 os << "TVM_EXPORT extern const unsigned char " << runtime::symbol::tvm_dev_mblob << "[];\n";
258 uint64_t nbytes = bin.length();
259 os << "const unsigned char " << runtime::symbol::tvm_dev_mblob << "["
260 << bin.length() + sizeof(nbytes) << "] = {\n ";
261 os << std::hex;
262 size_t nunit = 80 / 4;
263 for (size_t i = 0; i < sizeof(nbytes); ++i) {
264 // sperators
265 if (i != 0) {
266 os << ",";
267 }
268 os << "0x" << ((nbytes >> (i * 8)) & 0xffUL);
269 }
270 for (size_t i = 0; i < bin.length(); ++i) {
271 // sperators
272 if ((i + sizeof(nbytes)) % nunit == 0) {
273 os << ",\n ";
274 } else {
275 os << ",";
276 }
277 int c = bin[i];
278 os << "0x" << (c & 0xff);
279 }
280 os << "\n};\n";
281 if (system_lib) {
282 os << "extern int TVMBackendRegisterSystemLibSymbol(const char*, void*);\n";
283 os << "static int " << runtime::symbol::tvm_dev_mblob << "_reg_ = "
284 << "TVMBackendRegisterSystemLibSymbol(\"" << runtime::symbol::tvm_dev_mblob << "\", (void*)"
285 << runtime::symbol::tvm_dev_mblob << ");\n";
286 }
287 os << "#ifdef __cplusplus\n"
288 << "}\n"
289 << "#endif\n";
290 return os.str();
291}
292
293runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib,
294 const std::string& llvm_target_string) {
295 std::string bin = SerializeModule(mod);
296
297 uint64_t nbytes = bin.length();
298 std::string header;
299 for (size_t i = 0; i < sizeof(nbytes); ++i) {
300 header.push_back(((nbytes >> (i * 8)) & 0xffUL));
301 }
302 std::string blob = header + bin;
303 TVMByteArray blob_byte_array;
304 blob_byte_array.size = blob.length();
305 blob_byte_array.data = blob.data();
306
307 // Call codegen_blob to generate LLVM module
308 std::string codegen_f_name = "codegen.codegen_blob";
309 // the codegen function.
310 const PackedFunc* codegen_f = runtime::Registry::Get(codegen_f_name);
311 ICHECK(codegen_f != nullptr) << "codegen.codegen_blob is not presented.";
312 return (*codegen_f)(blob_byte_array, system_lib, llvm_target_string);
313}
314
315TVM_REGISTER_GLOBAL("target.Build").set_body_typed(Build);
316
317// Export two auxiliary function to the runtime namespace.
318TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC").set_body_typed(PackImportsToC);
319
320TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToLLVM").set_body_typed(PackImportsToLLVM);
321
322} // namespace codegen
323} // namespace tvm
324