1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file source_module.cc |
22 | * \brief Source code module, only for viewing |
23 | */ |
24 | #include "source_module.h" |
25 | |
26 | #include <dmlc/memory_io.h> |
27 | #include <tvm/runtime/metadata.h> |
28 | #include <tvm/runtime/module.h> |
29 | #include <tvm/runtime/name_transforms.h> |
30 | #include <tvm/runtime/ndarray.h> |
31 | #include <tvm/runtime/packed_func.h> |
32 | #include <tvm/runtime/registry.h> |
33 | |
34 | #include <algorithm> |
35 | #include <functional> |
36 | #include <numeric> |
37 | #include <string> |
38 | #include <unordered_map> |
39 | #include <unordered_set> |
40 | #include <utility> |
41 | #include <vector> |
42 | |
43 | #include "../../relay/backend/name_transforms.h" |
44 | #include "../../runtime/file_utils.h" |
45 | #include "../../support/str_escape.h" |
46 | #include "../func_registry_generator.h" |
47 | #include "../metadata.h" |
48 | #include "../metadata_utils.h" |
49 | #include "codegen_params.h" |
50 | #include "codegen_source_base.h" |
51 | #include "tvm/relay/executor.h" |
52 | |
53 | namespace tvm { |
54 | namespace codegen { |
55 | |
56 | using runtime::PackedFunc; |
57 | using runtime::TVMArgs; |
58 | using runtime::TVMRetValue; |
59 | |
60 | using runtime::FunctionInfo; |
61 | using runtime::GetFileFormat; |
62 | using runtime::GetMetaFilePath; |
63 | using runtime::SaveBinaryToFile; |
64 | |
65 | // Simulator function |
66 | class SourceModuleNode : public runtime::ModuleNode { |
67 | public: |
68 | SourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {} |
69 | const char* type_key() const final { return "source" ; } |
70 | |
71 | PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final { |
72 | LOG(FATAL) << "Source module cannot execute, to get executable module" |
73 | << " build TVM with \'" << fmt_ << "\' runtime support" ; |
74 | return PackedFunc(); |
75 | } |
76 | |
77 | std::string GetSource(const std::string& format) final { return code_; } |
78 | |
79 | std::string GetFormat() { return fmt_; } |
80 | |
81 | protected: |
82 | std::string code_; |
83 | std::string fmt_; |
84 | }; |
85 | |
86 | runtime::Module SourceModuleCreate(std::string code, std::string fmt) { |
87 | auto n = make_object<SourceModuleNode>(code, fmt); |
88 | return runtime::Module(n); |
89 | } |
90 | |
91 | // Simulator function |
92 | class CSourceModuleNode : public runtime::ModuleNode { |
93 | public: |
94 | CSourceModuleNode(const std::string& code, const std::string& fmt, |
95 | const Array<String>& func_names, const Array<String>& const_vars) |
96 | : code_(code), fmt_(fmt), const_vars_(const_vars), func_names_(func_names) {} |
97 | const char* type_key() const final { return "c" ; } |
98 | |
99 | PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final { |
100 | // Currently c-source module is used as demonstration purposes with binary metadata module |
101 | // that expects get_symbol interface. When c-source module is used as external module, it |
102 | // will only contain one function. However, when its used as an internal module (e.g., target |
103 | // "c") it can have many functions. |
104 | if (name == "get_symbol" ) { |
105 | return PackedFunc( |
106 | [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_[0]; }); |
107 | } else if (name == "get_const_vars" ) { |
108 | return PackedFunc( |
109 | [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->const_vars_; }); |
110 | } else if (name == "get_func_names" ) { |
111 | return PackedFunc( |
112 | [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_; }); |
113 | } else { |
114 | return PackedFunc(nullptr); |
115 | } |
116 | } |
117 | |
118 | std::string GetSource(const std::string& format) final { return code_; } |
119 | |
120 | std::string GetFormat() { return fmt_; } |
121 | |
122 | void SaveToFile(const std::string& file_name, const std::string& format) final { |
123 | std::string fmt = GetFileFormat(file_name, format); |
124 | std::string meta_file = GetMetaFilePath(file_name); |
125 | if (fmt == "c" || fmt == "cc" || fmt == "cpp" || fmt == "cu" ) { |
126 | ICHECK_NE(code_.length(), 0); |
127 | SaveBinaryToFile(file_name, code_); |
128 | } else { |
129 | ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; |
130 | } |
131 | } |
132 | |
133 | bool IsDSOExportable() const final { return true; } |
134 | |
135 | bool ImplementsFunction(const String& name, bool query_imports) final { |
136 | return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); |
137 | } |
138 | |
139 | protected: |
140 | std::string code_; |
141 | std::string fmt_; |
142 | Array<String> const_vars_; |
143 | Array<String> func_names_; |
144 | }; |
145 | |
146 | runtime::Module CSourceModuleCreate(const String& code, const String& fmt, |
147 | const Array<String>& func_names, |
148 | const Array<String>& const_vars) { |
149 | auto n = make_object<CSourceModuleNode>(code.operator std::string(), fmt.operator std::string(), |
150 | func_names, const_vars); |
151 | return runtime::Module(n); |
152 | } |
153 | |
154 | /*! |
155 | * \brief A concrete class to get access to base methods of CodegenSourceBase. |
156 | * |
157 | * This class exist to get access to methods of CodegenSourceBase without duplicating |
158 | * them. Therefore, keeping alignment with how codegen and source_module here generates |
159 | * code. |
160 | */ |
161 | class ConcreteCodegenSourceBase : public CodeGenSourceBase { |
162 | /*! |
163 | * \brief Do nothing as this class exist to get access to methods of CodeGenSourceBase |
164 | */ |
165 | void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) final { |
166 | return; |
167 | } |
168 | }; |
169 | |
170 | class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { |
171 | public: |
172 | CSourceCrtMetadataModuleNode(const Array<String>& func_names, const std::string& fmt, |
173 | Target target, relay::Runtime runtime, |
174 | relay::backend::ExecutorCodegenMetadata metadata) |
175 | : fmt_(fmt), |
176 | func_names_(func_names), |
177 | target_(target), |
178 | runtime_(runtime), |
179 | metadata_(metadata) { |
180 | CreateSource(); |
181 | } |
182 | const char* type_key() const final { return "c" ; } |
183 | |
184 | std::string GetSource(const std::string& format) final { return code_.str(); } |
185 | |
186 | std::string GetFormat() { return fmt_; } |
187 | PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final { |
188 | return PackedFunc(); |
189 | } |
190 | |
191 | void SaveToFile(const std::string& file_name, const std::string& format) final { |
192 | std::string fmt = GetFileFormat(file_name, format); |
193 | std::string meta_file = GetMetaFilePath(file_name); |
194 | if (fmt == "c" || fmt == "cc" || fmt == "cpp" ) { |
195 | auto code_str = code_.str(); |
196 | ICHECK_NE(code_str.length(), 0); |
197 | SaveBinaryToFile(file_name, code_str); |
198 | } else { |
199 | ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; |
200 | } |
201 | } |
202 | |
203 | bool IsDSOExportable() const final { return true; } |
204 | |
205 | bool ImplementsFunction(const String& name, bool query_imports) final { |
206 | return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); |
207 | } |
208 | |
209 | protected: |
210 | std::stringstream code_; |
211 | std::string fmt_; |
212 | Array<String> func_names_; |
213 | Target target_; |
214 | relay::Runtime runtime_; |
215 | relay::backend::ExecutorCodegenMetadata metadata_; |
216 | ConcreteCodegenSourceBase codegen_c_base_; |
217 | |
218 | void CreateFuncRegistry() { |
219 | code_ << "#include <tvm/runtime/crt/module.h>\n" ; |
220 | for (const auto& fname : func_names_) { |
221 | code_ << "#ifdef __cplusplus\n" ; |
222 | code_ << "extern \"C\"\n" ; |
223 | code_ << "#endif\n" ; |
224 | code_ << "TVM_DLL int32_t " << fname.data(); |
225 | code_ << "(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, " |
226 | "int* out_type_code, void* resource_handle);\n" ; |
227 | } |
228 | code_ << "static TVMBackendPackedCFunc _tvm_func_array[] = {\n" ; |
229 | for (auto f : func_names_) { |
230 | code_ << " (TVMBackendPackedCFunc)" << f << ",\n" ; |
231 | } |
232 | code_ << "};\n" ; |
233 | auto registry = target::GenerateFuncRegistryNames(func_names_); |
234 | code_ << "static const TVMFuncRegistry _tvm_func_registry = {\n" |
235 | << " \"" << ::tvm::support::StrEscape(registry.data(), registry.size(), true) << "\"," |
236 | << " _tvm_func_array,\n" |
237 | << "};\n" ; |
238 | } |
239 | |
240 | void GenerateCrtSystemLib() { |
241 | code_ << "static const TVMModule _tvm_system_lib = {\n" |
242 | << " &_tvm_func_registry,\n" |
243 | << "};\n" |
244 | << "const TVMModule* TVMSystemLibEntryPoint(void) {\n" |
245 | << " return &_tvm_system_lib;\n" |
246 | << "}\n" ; |
247 | } |
248 | |
249 | String GenerateDLTensorStructWrapper(String reference_arg) { |
250 | code_ << "DLTensor " << reference_arg << "_dltensor = {\n" ; |
251 | code_ << ".data = &" << reference_arg << "\n" ; |
252 | code_ << "};\n" ; |
253 | code_ << "TVMValue " << reference_arg << "_tvm_value = {\n" ; |
254 | code_ << ".v_handle = &" << reference_arg << "_dltensor\n" ; |
255 | code_ << "};\n" ; |
256 | return reference_arg + "_tvm_value" ; |
257 | } |
258 | |
259 | void GenerateInternalBuffers() { |
260 | if (metadata_->pool_inputs.defined()) { |
261 | for (const auto& kv : metadata_->pool_inputs.value()) { |
262 | tir::usmp::AllocatedPoolInfo allocated_pool_info = kv.second; |
263 | if (allocated_pool_info->pool_info->is_internal) { |
264 | if (const auto* pool_info = allocated_pool_info->pool_info.as<ConstantPoolInfoNode>()) { |
265 | GenerateConstantBuffer(pool_info, allocated_pool_info->allocated_size->value); |
266 | } else { |
267 | GenerateWorkspaceBuffer(allocated_pool_info->pool_info.as<WorkspacePoolInfoNode>(), |
268 | allocated_pool_info->allocated_size->value); |
269 | } |
270 | } |
271 | } |
272 | } |
273 | } |
274 | |
275 | void GenerateIOWorkspaceMapFunction(const std::string& struct_type, |
276 | const std::string& function_name, |
277 | const Array<String>& tensor_names) { |
278 | std::string map_function = runtime::get_name_mangled(metadata_->mod_name, function_name); |
279 | code_ << "struct " << struct_type << " " << map_function << "(\n" ; |
280 | std::string pools_struct = runtime::get_name_mangled(metadata_->mod_name, "workspace_pools" ); |
281 | code_ << " struct " << pools_struct << "* workspace_pools\n" ; |
282 | code_ << "\n){\n" ; |
283 | code_ << "struct " << struct_type << " ret = {\n" ; |
284 | for (const String& name : tensor_names) { |
285 | tir::usmp::PoolAllocation pool_allocation = metadata_->io_pool_allocations[name]; |
286 | code_ << "\t." << name << " = " |
287 | << "&((uint8_t*)workspace_pools->" << pool_allocation->pool_info->pool_name << ")[" |
288 | << pool_allocation->byte_offset << "],\n" ; |
289 | } |
290 | code_ << "};\n" ; |
291 | code_ << "return ret;\n" ; |
292 | code_ << "}\n\n" ; |
293 | } |
294 | |
295 | void GenerateConstantBuffer(const ConstantPoolInfoNode* pool_info, size_t allocated_size) { |
296 | size_t ord = 0; |
297 | if (pool_info->constant_info_array.size() > 0) { |
298 | // Pool is RO, form an initialized struct |
299 | code_ << "__attribute__((section(\".rodata.tvm\"), " ; |
300 | code_ << "))\n" ; |
301 | code_ << "static struct " << pool_info->pool_name << " {\n" ; |
302 | // emit struct field names |
303 | std::vector<ConstantInfo> const_info_vec(pool_info->constant_info_array.begin(), |
304 | pool_info->constant_info_array.end()); |
305 | std::sort(const_info_vec.begin(), const_info_vec.end(), |
306 | [](const ConstantInfo& a, const ConstantInfo& b) { |
307 | return a->byte_offset->value < b->byte_offset->value; |
308 | }); |
309 | for (const auto& const_info : const_info_vec) { |
310 | const auto& data = const_info->data; |
311 | const auto& offs = const_info->byte_offset; |
312 | int64_t num_elements = std::accumulate(data.Shape().begin(), data.Shape().end(), 1, |
313 | std::multiplies<int64_t>()); |
314 | code_ << " " ; |
315 | codegen_c_base_.PrintType(data.DataType(), code_); |
316 | code_ << " " << const_info->name_hint << "[" << num_elements << "] __attribute__((" |
317 | << (ord++ ? "packed, " : "" ) << "aligned(" << metadata_->constant_alignment << ")));" ; |
318 | code_ << " // " << num_elements * data.DataType().bytes() |
319 | << " bytes, aligned offset: " << offs << "\n" ; |
320 | } |
321 | code_ << "} " << pool_info->pool_name << " = {\n" ; |
322 | |
323 | // emit struct field initialization data |
324 | for (const auto& const_info : const_info_vec) { |
325 | code_ << " ." << const_info->name_hint << " = {\n" ; |
326 | codegen::NDArrayDataToC(const_info->data, 4, code_); |
327 | code_ << " },\n" ; |
328 | } |
329 | code_ << "};" ; |
330 | code_ << "// of total size " << allocated_size << " bytes\n" ; |
331 | } else { |
332 | LOG(FATAL) << "No constant data in constant pool found " << GetRef<ObjectRef>(pool_info); |
333 | } |
334 | } |
335 | |
336 | void GenerateWorkspaceBuffer(const WorkspacePoolInfoNode* pool_info, size_t allocated_size) { |
337 | code_ << "__attribute__((section(\".bss.noinit.tvm\"), " ; |
338 | code_ << "aligned(" << metadata_->workspace_alignment << ")))\n" ; |
339 | code_ << "static uint8_t " << pool_info->pool_name << "[" ; |
340 | code_ << allocated_size << "];\n" ; |
341 | } |
342 | |
343 | bool IsInternalWorkspaceBuffer(const tir::Var& pool_var) { |
344 | if (metadata_->pool_inputs.defined()) { |
345 | Map<tir::Var, tir::usmp::AllocatedPoolInfo> allocated_pool_infos = |
346 | metadata_->pool_inputs.value(); |
347 | if (allocated_pool_infos.find(pool_var) != allocated_pool_infos.end()) { |
348 | tir::usmp::AllocatedPoolInfo allocate_pool_info = allocated_pool_infos[pool_var]; |
349 | if (allocate_pool_info->pool_info->is_internal) { |
350 | return true; |
351 | } |
352 | } |
353 | } |
354 | return false; |
355 | } |
356 | |
357 | void GenerateEntrypointForUnpackedAPI(const std::string& entrypoint_name, |
358 | const std::string& run_func) { |
359 | code_ << "TVM_DLL int32_t " << run_func << "(" ; |
360 | |
361 | { |
362 | std::stringstream call_args_ss; |
363 | if (metadata_->io_pool_allocations.empty()) { |
364 | for (const tir::Var& input_var : metadata_->inputs) { |
365 | if (input_var->type_annotation.defined()) { |
366 | codegen_c_base_.PrintType(input_var->type_annotation, call_args_ss); |
367 | } else { |
368 | codegen_c_base_.PrintType(input_var.dtype(), call_args_ss); |
369 | } |
370 | call_args_ss << " " << input_var->name_hint << "," ; |
371 | } |
372 | for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) { |
373 | call_args_ss << "void* output" << i << "," ; |
374 | } |
375 | } |
376 | for (const tir::Var& pool_var : metadata_->pools) { |
377 | if (pool_var->type_annotation.defined()) { |
378 | codegen_c_base_.PrintType(pool_var->type_annotation, call_args_ss); |
379 | } else { |
380 | codegen_c_base_.PrintType(pool_var.dtype(), call_args_ss); |
381 | } |
382 | call_args_ss << " " << pool_var->name_hint << "," ; |
383 | } |
384 | std::string call_args_str = call_args_ss.str(); |
385 | call_args_str.pop_back(); |
386 | code_ << call_args_str; |
387 | } |
388 | |
389 | code_ << ");\n" ; |
390 | code_ << "int32_t " << entrypoint_name; |
391 | code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " |
392 | "out_type_code, void* resource_handle) {\n" ; |
393 | code_ << "return " << run_func << "(" ; |
394 | |
395 | { |
396 | std::stringstream call_args_ss; |
397 | if (metadata_->io_pool_allocations.empty()) { |
398 | for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) { |
399 | call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data," ; |
400 | } |
401 | for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) { |
402 | int j = metadata_->inputs.size() + i; |
403 | call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data," ; |
404 | } |
405 | } |
406 | for (const tir::Var& pool_var : metadata_->pools) { |
407 | if (IsInternalWorkspaceBuffer(pool_var)) { |
408 | call_args_ss << "&" << metadata_->pool_inputs.value()[pool_var]->pool_info->pool_name |
409 | << "," ; |
410 | } |
411 | } |
412 | std::string call_args_str = call_args_ss.str(); |
413 | call_args_str.pop_back(); |
414 | code_ << call_args_str; |
415 | code_ << ");\n" ; |
416 | code_ << "}\n" ; |
417 | } |
418 | } |
419 | |
420 | std::unordered_map<int, ObjectRef> GenerateRunFuncToEntryPointArgMap() { |
421 | std::unordered_map<int, ObjectRef> run_func_to_entry_point_args; |
422 | int entrypoint_arg_count = 0; |
423 | int run_func_arg_count = 0; |
424 | |
425 | if (metadata_->io_pool_allocations.empty()) { |
426 | for (unsigned int i = 0; i < metadata_->inputs.size(); i++) { |
427 | run_func_to_entry_point_args[run_func_arg_count] = Integer(entrypoint_arg_count); |
428 | entrypoint_arg_count++; |
429 | run_func_arg_count++; |
430 | } |
431 | for (unsigned int i = 0; i < metadata_->outputs.size(); i++) { |
432 | run_func_to_entry_point_args[run_func_arg_count] = Integer(entrypoint_arg_count); |
433 | entrypoint_arg_count++; |
434 | run_func_arg_count++; |
435 | } |
436 | } |
437 | for (const tir::Var& pool_var : metadata_->pools) { |
438 | if (IsInternalWorkspaceBuffer(pool_var)) { |
439 | tir::usmp::AllocatedPoolInfo allocated_pool_info = metadata_->pool_inputs.value()[pool_var]; |
440 | run_func_to_entry_point_args[run_func_arg_count] = |
441 | allocated_pool_info->pool_info->pool_name; |
442 | run_func_arg_count++; |
443 | } |
444 | } |
445 | return run_func_to_entry_point_args; |
446 | } |
447 | |
448 | void GenerateEntrypointForPackedAPI(const std::string& entrypoint_name, |
449 | const std::string& run_func) { |
450 | code_ << "TVM_DLL int32_t " << run_func; |
451 | code_ << "(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, int* " |
452 | "out_type_code, void* resource_handle);\n\n" ; |
453 | |
454 | code_ << "int32_t " << entrypoint_name; |
455 | code_ << "(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, int* " |
456 | "out_type_code, void* resource_handle) {\n" ; |
457 | |
458 | // We are creating a copy of the set of pointers |
459 | size_t number_of_io_tensors = metadata_->inputs.size() + metadata_->outputs.size() + |
460 | metadata_->pools.size() - metadata_->io_pool_allocations.size(); |
461 | code_ << "TVMValue tensors[" << number_of_io_tensors << "];\n" ; |
462 | |
463 | std::unordered_map<int, ObjectRef> run_func_to_entry_point_args = |
464 | GenerateRunFuncToEntryPointArgMap(); |
465 | for (unsigned int i = 0; i < number_of_io_tensors; i++) { |
466 | if (run_func_to_entry_point_args.find(i) != run_func_to_entry_point_args.end()) { |
467 | if (run_func_to_entry_point_args[i]->IsInstance<StringObj>()) { |
468 | String pool_name = Downcast<String>(run_func_to_entry_point_args[i]); |
469 | String pool_name_tvmv = GenerateDLTensorStructWrapper(pool_name); |
470 | code_ << "tensors[" << i << "] = " << pool_name_tvmv << ";\n" ; |
471 | } else { |
472 | code_ << "tensors[" << i << "] = ((TVMValue*)args)[" << run_func_to_entry_point_args[i] |
473 | << "];\n" ; |
474 | } |
475 | } |
476 | } |
477 | |
478 | code_ << "return " << run_func; |
479 | code_ << "((void*)tensors, type_code, num_args, out_value, out_type_code, resource_handle);\n" ; |
480 | code_ << "}\n" ; |
481 | } |
482 | |
483 | static int isNotAlnum(char c) { return !std::isalnum(c); } |
484 | |
485 | void GenerateCInterfaceEntrypoint(const std::string& entrypoint_name, const std::string& run_func, |
486 | const std::string& mod_name) { |
487 | code_ << "#include <" << mod_name << ".h>\n" ; |
488 | if (!metadata_->io_pool_allocations.empty()) { |
489 | const std::string input_struct_type = |
490 | runtime::get_name_mangled(metadata_->mod_name, "inputs" ); |
491 | Array<String> input_tensor_names; |
492 | for (const tir::Var& input_var : metadata_->inputs) { |
493 | input_tensor_names.push_back(input_var->name_hint); |
494 | } |
495 | GenerateIOWorkspaceMapFunction(input_struct_type, "map_inputs" , input_tensor_names); |
496 | const std::string output_struct_type = |
497 | runtime::get_name_mangled(metadata_->mod_name, "outputs" ); |
498 | GenerateIOWorkspaceMapFunction(output_struct_type, "map_outputs" , metadata_->outputs); |
499 | } |
500 | code_ << "TVM_DLL int32_t " << run_func << "(" ; |
501 | { |
502 | std::stringstream call_args_ss; |
503 | if (metadata_->io_pool_allocations.empty()) { |
504 | for (const tir::Var& input_var : metadata_->inputs) { |
505 | if (input_var->type_annotation.defined()) { |
506 | codegen_c_base_.PrintType(input_var->type_annotation, call_args_ss); |
507 | } else { |
508 | codegen_c_base_.PrintType(input_var.dtype(), call_args_ss); |
509 | } |
510 | call_args_ss << " " << tvm::runtime::SanitizeName(input_var->name_hint) << "," ; |
511 | } |
512 | for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) { |
513 | call_args_ss << "void* output" << i << "," ; |
514 | } |
515 | } |
516 | for (const tir::Var& pool_var : metadata_->pools) { |
517 | if (pool_var->type_annotation.defined()) { |
518 | codegen_c_base_.PrintType(pool_var->type_annotation, call_args_ss); |
519 | } else { |
520 | codegen_c_base_.PrintType(pool_var.dtype(), call_args_ss); |
521 | } |
522 | call_args_ss << " " << pool_var->name_hint << "," ; |
523 | } |
524 | for (const String& device : metadata_->devices) { |
525 | call_args_ss << "void* " << device << "," ; |
526 | } |
527 | std::string call_args_str = call_args_ss.str(); |
528 | call_args_str.pop_back(); |
529 | code_ << call_args_str; |
530 | } |
531 | |
532 | code_ << ");\n" ; |
533 | code_ << "int32_t " << entrypoint_name << "(" ; |
534 | { |
535 | std::stringstream call_args_ss; |
536 | if (metadata_->io_pool_allocations.empty()) { |
537 | call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "inputs" ) << "* inputs," ; |
538 | call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "outputs" ) << "* outputs," ; |
539 | } |
540 | if (!metadata_->pools.empty()) { |
541 | bool is_external_pools_present = false; |
542 | for (tir::Var pool_var : metadata_->pools) { |
543 | if (!IsInternalWorkspaceBuffer(pool_var)) { |
544 | is_external_pools_present = true; |
545 | break; |
546 | } |
547 | } |
548 | if (is_external_pools_present) { |
549 | call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "workspace_pools" ) |
550 | << "* workspace_pools," ; |
551 | } |
552 | } |
553 | if (!metadata_->devices.empty()) { |
554 | call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "devices" ) << "* devices," ; |
555 | } |
556 | std::string call_args_str = call_args_ss.str(); |
557 | call_args_str.pop_back(); |
558 | code_ << call_args_str; |
559 | } |
560 | |
561 | code_ << ") {" |
562 | << "return " << run_func << "(" ; |
563 | |
564 | { |
565 | std::stringstream call_args_ss; |
566 | if (metadata_->io_pool_allocations.empty()) { |
567 | for (const auto& input : metadata_->inputs) { |
568 | call_args_ss << "inputs->" << tvm::runtime::SanitizeName(input->name_hint) << "," ; |
569 | } |
570 | for (const auto& output : metadata_->outputs) { |
571 | call_args_ss << "outputs->" << tvm::runtime::SanitizeName(output); |
572 | call_args_ss << "," ; |
573 | } |
574 | } |
575 | |
576 | for (const tir::Var& pool_var : metadata_->pools) { |
577 | String pool_name = metadata_->pool_inputs.value()[pool_var]->pool_info->pool_name; |
578 | if (IsInternalWorkspaceBuffer(pool_var)) { |
579 | call_args_ss << "&" << pool_name << "," ; |
580 | } else { |
581 | call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name) << "," ; |
582 | } |
583 | } |
584 | for (const String& device : metadata_->devices) { |
585 | call_args_ss << "devices->" << device << "," ; |
586 | } |
587 | std::string call_args_str = call_args_ss.str(); |
588 | call_args_str.pop_back(); |
589 | code_ << call_args_str; |
590 | } |
591 | code_ << ");\n" ; |
592 | code_ << "}\n" ; |
593 | } |
594 | |
595 | void GenerateAOTDescriptor() { |
596 | const std::string run_func_suffix = ::tvm::runtime::symbol::tvm_module_main; |
597 | const std::string tvm_entrypoint_suffix = ::tvm::runtime::symbol::tvm_entrypoint_suffix; |
598 | const std::string run_func_mangled = |
599 | runtime::get_name_mangled(metadata_->mod_name, run_func_suffix); |
600 | const std::string entrypoint_mangled = |
601 | runtime::get_name_mangled(metadata_->mod_name, tvm_entrypoint_suffix); |
602 | const std::string network_mangled = runtime::get_name_mangled(metadata_->mod_name, "network" ); |
603 | |
604 | code_ << "#include \"tvm/runtime/c_runtime_api.h\"\n" ; |
605 | code_ << "#ifdef __cplusplus\n" ; |
606 | code_ << "extern \"C\" {\n" ; |
607 | code_ << "#endif\n" ; |
608 | |
609 | GenerateInternalBuffers(); |
610 | |
611 | if (metadata_->unpacked_api) { |
612 | if (metadata_->interface_api == "c" ) { |
613 | GenerateCInterfaceEntrypoint(entrypoint_mangled, run_func_mangled, metadata_->mod_name); |
614 | } else { |
615 | GenerateEntrypointForUnpackedAPI(entrypoint_mangled, run_func_mangled); |
616 | } |
617 | } else { |
618 | ICHECK_EQ(metadata_->interface_api, "packed" ) |
619 | << "Packed interface required for packed operators" ; |
620 | GenerateEntrypointForPackedAPI(entrypoint_mangled, run_func_mangled); |
621 | } |
622 | |
623 | code_ << "#ifdef __cplusplus\n" ; |
624 | code_ << "}\n" ; |
625 | code_ << "#endif\n" ; |
626 | } |
627 | |
628 | void CreateSource() { |
629 | if (runtime_->GetAttr<Bool>("system-lib" ).value_or(Bool(false)) && !func_names_.empty()) { |
630 | CreateFuncRegistry(); |
631 | GenerateCrtSystemLib(); |
632 | } |
633 | if (metadata_.defined() && metadata_->executor == runtime::kTvmExecutorAot) { |
634 | GenerateAOTDescriptor(); |
635 | } |
636 | code_ << ";" ; |
637 | } |
638 | }; |
639 | |
640 | class MetadataSerializer : public AttrVisitor { |
641 | public: |
642 | static constexpr const char* kGlobalSymbol = "kTvmgenMetadata" ; |
643 | using MetadataKind = ::tvm::runtime::metadata::MetadataKind; |
644 | |
645 | MetadataSerializer() : is_first_item_{true} {} |
646 | |
647 | void WriteComma() { |
648 | if (is_first_item_) { |
649 | is_first_item_ = false; |
650 | } else { |
651 | code_ << ", " << std::endl; |
652 | } |
653 | } |
654 | |
655 | void WriteKey(const char* key) { |
656 | if (key != nullptr) { |
657 | code_ << " /* " << key << "*/" ; |
658 | } |
659 | } |
660 | |
661 | void Visit(const char* key, double* value) final { |
662 | WriteComma(); |
663 | code_.setf(std::ios::hex | std::ios::showbase | std::ios::fixed | std::ios::scientific, |
664 | std::ios::basefield | std::ios::showbase | std::ios::floatfield); |
665 | code_ << *value; |
666 | WriteKey(key); |
667 | } |
668 | |
669 | void Visit(const char* key, int64_t* value) final { |
670 | WriteComma(); |
671 | code_ << *value << "L" ; |
672 | WriteKey(key); |
673 | } |
674 | |
675 | void Visit(const char* key, uint64_t* value) final { |
676 | WriteComma(); |
677 | code_ << *value << "UL" ; |
678 | WriteKey(key); |
679 | } |
680 | void Visit(const char* key, int* value) final { |
681 | WriteComma(); |
682 | code_ << *value; |
683 | WriteKey(key); |
684 | } |
685 | void Visit(const char* key, bool* value) final { |
686 | WriteComma(); |
687 | code_ << *value; |
688 | WriteKey(key); |
689 | } |
690 | void Visit(const char* key, std::string* value) final { |
691 | WriteComma(); |
692 | code_ << "\"" << *value << "\"" ; |
693 | WriteKey(key); |
694 | } |
695 | void Visit(const char* key, void** value) final { |
696 | WriteComma(); |
697 | code_ << *value; |
698 | WriteKey(key); |
699 | } |
700 | void Visit(const char* key, DataType* value) final { |
701 | WriteComma(); |
702 | code_ << "{" << value->code() << ", " << value->bits() << ", " << value->lanes() << "}" ; |
703 | WriteKey(key); |
704 | } |
705 | |
706 | // Serialiding NDArray as tuple of len, data |
707 | void Visit(const char* key, runtime::NDArray* value) final { |
708 | WriteComma(); |
709 | std::string bytes; |
710 | dmlc::MemoryStringStream stream(&bytes); |
711 | value->Save(&stream); |
712 | // Serializing length of the data of NDArray |
713 | code_ << stream.Tell(); |
714 | WriteComma(); |
715 | // Serializing NDArray as bytestream |
716 | code_ << "\"" ; |
717 | std::stringstream ss; |
718 | char buf[6] = {0}; |
719 | for (uint8_t c : bytes) { |
720 | snprintf(buf, sizeof(buf), "\\x%02x" , c); |
721 | ss << buf; |
722 | } |
723 | std::string as_bytes(ss.str()); |
724 | code_ << as_bytes; |
725 | code_ << "\"\n" ; |
726 | } |
727 | |
728 | void VisitArray(runtime::metadata::MetadataArray array) { |
729 | auto old_is_first_item = is_first_item_; |
730 | is_first_item_ = true; |
731 | for (unsigned int i = 0; i < array->array.size(); ++i) { |
732 | ObjectRef o = array->array[i]; |
733 | |
734 | switch (array->kind) { |
735 | case MetadataKind::kUint64: { |
736 | int64_t i = Downcast<Integer>(o).IntValue(); |
737 | CHECK_GT(i, 0) |
738 | << "Metadata is of type uint64_t, but array type contains a negative number" ; |
739 | uint64_t ui = static_cast<uint64_t>(i); |
740 | Visit(nullptr, &ui); |
741 | continue; |
742 | } |
743 | case MetadataKind::kInt64: { |
744 | int64_t i = Downcast<Integer>(o).IntValue(); |
745 | Visit(nullptr, &i); |
746 | continue; |
747 | } |
748 | case MetadataKind::kBool: { |
749 | bool b = Downcast<Bool>(o); |
750 | Visit(nullptr, &b); |
751 | break; |
752 | } |
753 | case MetadataKind::kString: { |
754 | std::string s = Downcast<String>(o); |
755 | Visit(nullptr, &s); |
756 | break; |
757 | } |
758 | case MetadataKind::kHandle: |
759 | CHECK(false) << "Don't know how to serialize handle" ; |
760 | break; |
761 | |
762 | case MetadataKind::kMetadata: { |
763 | runtime::metadata::MetadataBase metadata = Downcast<runtime::metadata::MetadataBase>(o); |
764 | std::stringstream i_str; |
765 | i_str << i; |
766 | address_.push_back(i_str.str()); |
767 | Visit(nullptr, &metadata); |
768 | address_.pop_back(); |
769 | break; |
770 | } |
771 | default: |
772 | CHECK(false) << "Unknown MetadataKind for array: " << array->kind; |
773 | break; |
774 | } |
775 | is_first_item_ = false; |
776 | } |
777 | is_first_item_ = old_is_first_item; |
778 | } |
779 | |
780 | void Visit(const char* key, ObjectRef* value) final { |
781 | const runtime::metadata::MetadataArrayNode* arr = |
782 | value->as<runtime::metadata::MetadataArrayNode>(); |
783 | if (arr != nullptr) { |
784 | WriteComma(); |
785 | if (key != nullptr) { |
786 | address_.push_back(key); |
787 | } |
788 | code_ << metadata::AddressFromParts(address_); |
789 | if (key != nullptr) { |
790 | address_.pop_back(); |
791 | } |
792 | return; |
793 | } |
794 | |
795 | runtime::metadata::MetadataBase metadata = Downcast<runtime::metadata::MetadataBase>(*value); |
796 | if (key != nullptr) { // NOTE: outermost call passes nullptr key |
797 | address_.push_back(key); |
798 | } |
799 | WriteComma(); |
800 | code_ << "{\n" ; |
801 | is_first_item_ = true; |
802 | ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this); |
803 | code_ << "}\n" ; |
804 | if (key != nullptr) { // NOTE: outermost call passes nullptr key |
805 | address_.pop_back(); |
806 | } |
807 | } |
808 | |
809 | private: |
810 | void EmitCType(const runtime::metadata::MetadataArrayNode* arr, std::ostream& os) { |
811 | switch (arr->kind) { |
812 | case MetadataKind::kUint64: |
813 | os << "uint64_t" ; |
814 | break; |
815 | case MetadataKind::kInt64: |
816 | os << "int64_t" ; |
817 | break; |
818 | case MetadataKind::kBool: |
819 | os << "bool" ; |
820 | break; |
821 | case MetadataKind::kString: |
822 | os << "const char*" ; |
823 | break; |
824 | case MetadataKind::kHandle: |
825 | os << "void*" ; |
826 | break; |
827 | case MetadataKind::kMetadata: |
828 | os << "struct " << arr->get_element_c_struct_name(); |
829 | break; |
830 | default: |
831 | CHECK(false) << "Unknown kind in MetadataArray: " << arr->kind |
832 | << " (struct_name=" << arr->get_c_struct_name() << ")" ; |
833 | break; |
834 | } |
835 | } |
836 | |
837 | public: |
838 | void CodegenMetadata(::tvm::runtime::metadata::Metadata metadata) { |
839 | decl_ << "#include <inttypes.h>" << std::endl |
840 | << "#include <tvm/runtime/metadata_types.h>" << std::endl |
841 | << "#include <tvm/runtime/c_runtime_api.h>" << std::endl; |
842 | std::vector<metadata::DiscoverArraysVisitor::DiscoveredArray> queue; |
843 | metadata::DiscoverArraysVisitor array_discover{&queue}; |
844 | array_discover.Visit(metadata::kMetadataGlobalSymbol, &metadata); |
845 | |
846 | for (auto item : queue) { |
847 | auto struct_address = std::get<0>(item); |
848 | address_.push_back(struct_address); |
849 | |
850 | auto arr = std::get<1>(item); |
851 | |
852 | // Prepend const with everything except C-string, which needs appending. |
853 | code_ << "static " ; |
854 | if (arr->kind != MetadataKind::kString) { |
855 | code_ << "const " ; |
856 | } |
857 | EmitCType(arr.operator->(), code_); |
858 | if (arr->kind == MetadataKind::kString) { |
859 | code_ << " const" ; |
860 | } |
861 | code_ << " " << struct_address << "[" << arr->array.size() << "] = {" << std::endl; |
862 | is_first_item_ = true; |
863 | |
864 | VisitArray(arr); |
865 | address_.pop_back(); |
866 | code_ << "};" << std::endl; |
867 | } |
868 | |
869 | // Finally, emit overall struct. |
870 | address_.push_back(metadata::kMetadataGlobalSymbol); |
871 | code_ << "static const struct TVMMetadata " << metadata::AddressFromParts(address_) << "[1] = {" |
872 | << std::endl; |
873 | Visit(nullptr, &metadata); |
874 | code_ << "};" << std::endl; |
875 | address_.pop_back(); |
876 | } |
877 | |
878 | std::string GetOutput() { return decl_.str() + code_.str(); } |
879 | |
880 | private: |
/*! \brief Path components of the symbol being emitted; joined by metadata::AddressFromParts. */
std::vector<std::string> address_;
/*! \brief Text emitted ahead of the definitions (the #include lines). */
std::stringstream decl_;
/*! \brief Text of the generated array and struct definitions. */
std::stringstream code_;
/*!
 * \brief Whether the next item written is the first one of its initializer list.
 *
 * Given a default member initializer so the flag has a defined value no matter
 * which constructor runs.
 */
bool is_first_item_{true};
/*! \brief NOTE(review): not referenced in the part of the class visible here - confirm usage. */
std::unordered_set<std::string> generated_struct_decls_;
/*! \brief NOTE(review): not referenced in the part of the class visible here - confirm usage. */
std::vector<bool> is_defining_struct_;
887 | }; |
888 | |
889 | namespace { |
890 | runtime::Module CreateAotMetadataModule(runtime::metadata::Metadata aot_metadata, |
891 | bool is_c_runtime) { |
892 | MetadataSerializer serializer; |
893 | serializer.CodegenMetadata(aot_metadata); |
894 | std::stringstream lookup_func; |
895 | std::string get_c_metadata_func_name; |
896 | |
897 | // NOTE: mangling is not needed in the c++ runtime because the function |
898 | // name is looked-up via LibraryModule. |
899 | // TODO(alanmacd): unify these two approaches |
900 | |
901 | if (is_c_runtime == true) { |
902 | get_c_metadata_func_name = runtime::get_name_mangled( |
903 | aot_metadata->mod_name(), ::tvm::runtime::symbol::tvm_get_c_metadata); |
904 | } else { |
905 | get_c_metadata_func_name = ::tvm::runtime::symbol::tvm_get_c_metadata; |
906 | } |
907 | |
908 | lookup_func << "#ifdef __cplusplus\n" |
909 | << "extern \"C\"\n" |
910 | << "#endif\n" ; |
911 | |
912 | lookup_func << "TVM_DLL int32_t " << get_c_metadata_func_name |
913 | << "(TVMValue* arg_values, int* arg_tcodes, int " |
914 | "num_args, TVMValue* ret_values, int* ret_tcodes, void* resource_handle) {" |
915 | << std::endl; |
916 | lookup_func << " ret_values[0].v_handle = (void*) &" << MetadataSerializer::kGlobalSymbol |
917 | << ";" << std::endl; |
918 | lookup_func << " ret_tcodes[0] = kTVMOpaqueHandle;" << std::endl; |
919 | lookup_func << " return 0;" << std::endl; |
920 | lookup_func << "};" << std::endl; |
921 | std::vector<String> func_names{get_c_metadata_func_name}; |
922 | return CSourceModuleCreate(serializer.GetOutput() + lookup_func.str(), "c" , func_names, |
923 | Array<String>()); |
924 | } |
925 | } // namespace |
926 | |
927 | runtime::Module CreateCSourceCrtMetadataModule(const Array<runtime::Module>& modules, Target target, |
928 | relay::Runtime runtime, |
929 | relay::backend::ExecutorCodegenMetadata metadata, |
930 | runtime::metadata::Metadata aot_metadata) { |
931 | Array<runtime::Module> final_modules(modules); |
932 | Array<String> func_names; |
933 | |
934 | if (metadata.defined()) { |
935 | if (metadata->executor == "aot" ) { |
936 | if (aot_metadata.defined()) { |
937 | final_modules.push_back(CreateAotMetadataModule(aot_metadata, true)); |
938 | } |
939 | |
940 | // add the run function (typically "tvmgen_default_run") to function registry |
941 | // when using AOT executor |
942 | std::string run_func = runtime::get_name_mangled(metadata->mod_name, "run" ); |
943 | func_names.push_back(run_func); |
944 | } |
945 | } |
946 | |
947 | for (runtime::Module mod : final_modules) { |
948 | auto pf_funcs = mod.GetFunction("get_func_names" ); |
949 | if (pf_funcs != nullptr) { |
950 | Array<String> func_names_ = pf_funcs(); |
951 | for (const auto& fname : func_names_) { |
952 | func_names.push_back(fname); |
953 | } |
954 | } |
955 | } |
956 | |
957 | auto n = make_object<CSourceCrtMetadataModuleNode>(func_names, "c" , target, runtime, metadata); |
958 | auto csrc_metadata_module = runtime::Module(n); |
959 | for (const auto& mod : final_modules) { |
960 | csrc_metadata_module.Import(mod); |
961 | } |
962 | |
963 | return std::move(csrc_metadata_module); |
964 | } |
965 | |
966 | runtime::Module CreateCSourceCppMetadataModule(runtime::metadata::Metadata metadata) { |
967 | MetadataSerializer serializer; |
968 | serializer.CodegenMetadata(metadata); |
969 | std::stringstream lookup_func; |
970 | lookup_func << "#ifdef __cplusplus\n" |
971 | << "extern \"C\"\n" |
972 | << "#endif\n" ; |
973 | |
974 | lookup_func << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_get_c_metadata |
975 | << "(TVMValue* arg_values, int* arg_tcodes, int " |
976 | "num_args, TVMValue* ret_values, int* ret_tcodes, void* resource_handle) {" |
977 | << std::endl; |
978 | lookup_func << " ret_values[0].v_handle = (void*) &" << metadata::kMetadataGlobalSymbol << ";" |
979 | << std::endl; |
980 | lookup_func << " ret_tcodes[0] = kTVMOpaqueHandle;" << std::endl; |
981 | lookup_func << " return 0;" << std::endl; |
982 | lookup_func << "};" << std::endl; |
983 | |
984 | auto mod = MetadataModuleCreate(metadata); |
985 | mod->Import(CreateAotMetadataModule(metadata, false)); |
986 | return mod; |
987 | } |
988 | |
989 | // supports limited save without cross compile |
990 | class DeviceSourceModuleNode final : public runtime::ModuleNode { |
991 | public: |
992 | DeviceSourceModuleNode(std::string data, std::string fmt, |
993 | std::unordered_map<std::string, FunctionInfo> fmap, std::string type_key, |
994 | std::function<std::string(const std::string&)> fget_source) |
995 | : data_(data), fmt_(fmt), fmap_(fmap), type_key_(type_key), fget_source_(fget_source) {} |
996 | |
997 | PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final { |
998 | LOG(FATAL) << "Source module cannot execute, to get executable module" |
999 | << " build TVM with \'" << fmt_ << "\' runtime support" ; |
1000 | return PackedFunc(); |
1001 | } |
1002 | |
1003 | std::string GetSource(const std::string& format) final { |
1004 | if (fget_source_ != nullptr) { |
1005 | return fget_source_(format); |
1006 | } else { |
1007 | return data_; |
1008 | } |
1009 | } |
1010 | |
1011 | const char* type_key() const final { return type_key_.c_str(); } |
1012 | |
1013 | void SaveToFile(const std::string& file_name, const std::string& format) final { |
1014 | std::string fmt = GetFileFormat(file_name, format); |
1015 | ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; |
1016 | std::string meta_file = GetMetaFilePath(file_name); |
1017 | SaveMetaDataToFile(meta_file, fmap_); |
1018 | SaveBinaryToFile(file_name, data_); |
1019 | } |
1020 | |
1021 | void SaveToBinary(dmlc::Stream* stream) final { |
1022 | stream->Write(fmt_); |
1023 | stream->Write(fmap_); |
1024 | stream->Write(data_); |
1025 | } |
1026 | |
1027 | private: |
1028 | std::string data_; |
1029 | std::string fmt_; |
1030 | std::unordered_map<std::string, FunctionInfo> fmap_; |
1031 | std::string type_key_; |
1032 | std::function<std::string(const std::string&)> fget_source_; |
1033 | }; |
1034 | |
1035 | runtime::Module DeviceSourceModuleCreate( |
1036 | std::string data, std::string fmt, std::unordered_map<std::string, FunctionInfo> fmap, |
1037 | std::string type_key, std::function<std::string(const std::string&)> fget_source) { |
1038 | auto n = make_object<DeviceSourceModuleNode>(data, fmt, fmap, type_key, fget_source); |
1039 | return runtime::Module(n); |
1040 | } |
1041 | |
1042 | TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate" ).set_body_typed(SourceModuleCreate); |
1043 | |
1044 | TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate" ) |
1045 | .set_body_typed([](String code, String fmt, Array<String> func_names, |
1046 | Array<String> const_vars) { |
1047 | return CSourceModuleCreate(code, fmt, func_names, const_vars); |
1048 | }); |
1049 | |
1050 | TVM_REGISTER_GLOBAL("runtime.CreateCSourceCrtMetadataModule" ) |
1051 | .set_body_typed([](const Array<runtime::Module>& modules, Target target, |
1052 | relay::Runtime runtime) { |
1053 | // Note that we don't need metadata when we compile a single operator |
1054 | return CreateCSourceCrtMetadataModule(modules, target, runtime, |
1055 | relay::backend::ExecutorCodegenMetadata(), |
1056 | runtime::metadata::Metadata()); |
1057 | }); |
1058 | |
1059 | } // namespace codegen |
1060 | } // namespace tvm |
1061 | |