1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file source_module.cc
22 * \brief Source code module, only for viewing
23 */
24#include "source_module.h"
25
26#include <dmlc/memory_io.h>
27#include <tvm/runtime/metadata.h>
28#include <tvm/runtime/module.h>
29#include <tvm/runtime/name_transforms.h>
30#include <tvm/runtime/ndarray.h>
31#include <tvm/runtime/packed_func.h>
32#include <tvm/runtime/registry.h>
33
34#include <algorithm>
35#include <functional>
36#include <numeric>
37#include <string>
38#include <unordered_map>
39#include <unordered_set>
40#include <utility>
41#include <vector>
42
43#include "../../relay/backend/name_transforms.h"
44#include "../../runtime/file_utils.h"
45#include "../../support/str_escape.h"
46#include "../func_registry_generator.h"
47#include "../metadata.h"
48#include "../metadata_utils.h"
49#include "codegen_params.h"
50#include "codegen_source_base.h"
51#include "tvm/relay/executor.h"
52
53namespace tvm {
54namespace codegen {
55
56using runtime::PackedFunc;
57using runtime::TVMArgs;
58using runtime::TVMRetValue;
59
60using runtime::FunctionInfo;
61using runtime::GetFileFormat;
62using runtime::GetMetaFilePath;
63using runtime::SaveBinaryToFile;
64
65// Simulator function
66class SourceModuleNode : public runtime::ModuleNode {
67 public:
68 SourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {}
69 const char* type_key() const final { return "source"; }
70
71 PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
72 LOG(FATAL) << "Source module cannot execute, to get executable module"
73 << " build TVM with \'" << fmt_ << "\' runtime support";
74 return PackedFunc();
75 }
76
77 std::string GetSource(const std::string& format) final { return code_; }
78
79 std::string GetFormat() { return fmt_; }
80
81 protected:
82 std::string code_;
83 std::string fmt_;
84};
85
86runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
87 auto n = make_object<SourceModuleNode>(code, fmt);
88 return runtime::Module(n);
89}
90
91// Simulator function
92class CSourceModuleNode : public runtime::ModuleNode {
93 public:
94 CSourceModuleNode(const std::string& code, const std::string& fmt,
95 const Array<String>& func_names, const Array<String>& const_vars)
96 : code_(code), fmt_(fmt), const_vars_(const_vars), func_names_(func_names) {}
97 const char* type_key() const final { return "c"; }
98
99 PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
100 // Currently c-source module is used as demonstration purposes with binary metadata module
101 // that expects get_symbol interface. When c-source module is used as external module, it
102 // will only contain one function. However, when its used as an internal module (e.g., target
103 // "c") it can have many functions.
104 if (name == "get_symbol") {
105 return PackedFunc(
106 [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_[0]; });
107 } else if (name == "get_const_vars") {
108 return PackedFunc(
109 [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->const_vars_; });
110 } else if (name == "get_func_names") {
111 return PackedFunc(
112 [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_; });
113 } else {
114 return PackedFunc(nullptr);
115 }
116 }
117
118 std::string GetSource(const std::string& format) final { return code_; }
119
120 std::string GetFormat() { return fmt_; }
121
122 void SaveToFile(const std::string& file_name, const std::string& format) final {
123 std::string fmt = GetFileFormat(file_name, format);
124 std::string meta_file = GetMetaFilePath(file_name);
125 if (fmt == "c" || fmt == "cc" || fmt == "cpp" || fmt == "cu") {
126 ICHECK_NE(code_.length(), 0);
127 SaveBinaryToFile(file_name, code_);
128 } else {
129 ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
130 }
131 }
132
133 bool IsDSOExportable() const final { return true; }
134
135 bool ImplementsFunction(const String& name, bool query_imports) final {
136 return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end();
137 }
138
139 protected:
140 std::string code_;
141 std::string fmt_;
142 Array<String> const_vars_;
143 Array<String> func_names_;
144};
145
146runtime::Module CSourceModuleCreate(const String& code, const String& fmt,
147 const Array<String>& func_names,
148 const Array<String>& const_vars) {
149 auto n = make_object<CSourceModuleNode>(code.operator std::string(), fmt.operator std::string(),
150 func_names, const_vars);
151 return runtime::Module(n);
152}
153
/*!
 * \brief A concrete class that gives access to the base methods of CodeGenSourceBase.
 *
 * This class exists so that the methods of CodeGenSourceBase can be reused without
 * duplicating them, which keeps the code generated by source_module here aligned with
 * the code generated by codegen.
 */
161class ConcreteCodegenSourceBase : public CodeGenSourceBase {
162 /*!
163 * \brief Do nothing as this class exist to get access to methods of CodeGenSourceBase
164 */
165 void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) final {
166 return;
167 }
168};
169
170class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
171 public:
172 CSourceCrtMetadataModuleNode(const Array<String>& func_names, const std::string& fmt,
173 Target target, relay::Runtime runtime,
174 relay::backend::ExecutorCodegenMetadata metadata)
175 : fmt_(fmt),
176 func_names_(func_names),
177 target_(target),
178 runtime_(runtime),
179 metadata_(metadata) {
180 CreateSource();
181 }
182 const char* type_key() const final { return "c"; }
183
184 std::string GetSource(const std::string& format) final { return code_.str(); }
185
186 std::string GetFormat() { return fmt_; }
187 PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
188 return PackedFunc();
189 }
190
191 void SaveToFile(const std::string& file_name, const std::string& format) final {
192 std::string fmt = GetFileFormat(file_name, format);
193 std::string meta_file = GetMetaFilePath(file_name);
194 if (fmt == "c" || fmt == "cc" || fmt == "cpp") {
195 auto code_str = code_.str();
196 ICHECK_NE(code_str.length(), 0);
197 SaveBinaryToFile(file_name, code_str);
198 } else {
199 ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
200 }
201 }
202
203 bool IsDSOExportable() const final { return true; }
204
205 bool ImplementsFunction(const String& name, bool query_imports) final {
206 return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end();
207 }
208
209 protected:
210 std::stringstream code_;
211 std::string fmt_;
212 Array<String> func_names_;
213 Target target_;
214 relay::Runtime runtime_;
215 relay::backend::ExecutorCodegenMetadata metadata_;
216 ConcreteCodegenSourceBase codegen_c_base_;
217
218 void CreateFuncRegistry() {
219 code_ << "#include <tvm/runtime/crt/module.h>\n";
220 for (const auto& fname : func_names_) {
221 code_ << "#ifdef __cplusplus\n";
222 code_ << "extern \"C\"\n";
223 code_ << "#endif\n";
224 code_ << "TVM_DLL int32_t " << fname.data();
225 code_ << "(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, "
226 "int* out_type_code, void* resource_handle);\n";
227 }
228 code_ << "static TVMBackendPackedCFunc _tvm_func_array[] = {\n";
229 for (auto f : func_names_) {
230 code_ << " (TVMBackendPackedCFunc)" << f << ",\n";
231 }
232 code_ << "};\n";
233 auto registry = target::GenerateFuncRegistryNames(func_names_);
234 code_ << "static const TVMFuncRegistry _tvm_func_registry = {\n"
235 << " \"" << ::tvm::support::StrEscape(registry.data(), registry.size(), true) << "\","
236 << " _tvm_func_array,\n"
237 << "};\n";
238 }
239
240 void GenerateCrtSystemLib() {
241 code_ << "static const TVMModule _tvm_system_lib = {\n"
242 << " &_tvm_func_registry,\n"
243 << "};\n"
244 << "const TVMModule* TVMSystemLibEntryPoint(void) {\n"
245 << " return &_tvm_system_lib;\n"
246 << "}\n";
247 }
248
249 String GenerateDLTensorStructWrapper(String reference_arg) {
250 code_ << "DLTensor " << reference_arg << "_dltensor = {\n";
251 code_ << ".data = &" << reference_arg << "\n";
252 code_ << "};\n";
253 code_ << "TVMValue " << reference_arg << "_tvm_value = {\n";
254 code_ << ".v_handle = &" << reference_arg << "_dltensor\n";
255 code_ << "};\n";
256 return reference_arg + "_tvm_value";
257 }
258
259 void GenerateInternalBuffers() {
260 if (metadata_->pool_inputs.defined()) {
261 for (const auto& kv : metadata_->pool_inputs.value()) {
262 tir::usmp::AllocatedPoolInfo allocated_pool_info = kv.second;
263 if (allocated_pool_info->pool_info->is_internal) {
264 if (const auto* pool_info = allocated_pool_info->pool_info.as<ConstantPoolInfoNode>()) {
265 GenerateConstantBuffer(pool_info, allocated_pool_info->allocated_size->value);
266 } else {
267 GenerateWorkspaceBuffer(allocated_pool_info->pool_info.as<WorkspacePoolInfoNode>(),
268 allocated_pool_info->allocated_size->value);
269 }
270 }
271 }
272 }
273 }
274
275 void GenerateIOWorkspaceMapFunction(const std::string& struct_type,
276 const std::string& function_name,
277 const Array<String>& tensor_names) {
278 std::string map_function = runtime::get_name_mangled(metadata_->mod_name, function_name);
279 code_ << "struct " << struct_type << " " << map_function << "(\n";
280 std::string pools_struct = runtime::get_name_mangled(metadata_->mod_name, "workspace_pools");
281 code_ << " struct " << pools_struct << "* workspace_pools\n";
282 code_ << "\n){\n";
283 code_ << "struct " << struct_type << " ret = {\n";
284 for (const String& name : tensor_names) {
285 tir::usmp::PoolAllocation pool_allocation = metadata_->io_pool_allocations[name];
286 code_ << "\t." << name << " = "
287 << "&((uint8_t*)workspace_pools->" << pool_allocation->pool_info->pool_name << ")["
288 << pool_allocation->byte_offset << "],\n";
289 }
290 code_ << "};\n";
291 code_ << "return ret;\n";
292 code_ << "}\n\n";
293 }
294
295 void GenerateConstantBuffer(const ConstantPoolInfoNode* pool_info, size_t allocated_size) {
296 size_t ord = 0;
297 if (pool_info->constant_info_array.size() > 0) {
298 // Pool is RO, form an initialized struct
299 code_ << "__attribute__((section(\".rodata.tvm\"), ";
300 code_ << "))\n";
301 code_ << "static struct " << pool_info->pool_name << " {\n";
302 // emit struct field names
303 std::vector<ConstantInfo> const_info_vec(pool_info->constant_info_array.begin(),
304 pool_info->constant_info_array.end());
305 std::sort(const_info_vec.begin(), const_info_vec.end(),
306 [](const ConstantInfo& a, const ConstantInfo& b) {
307 return a->byte_offset->value < b->byte_offset->value;
308 });
309 for (const auto& const_info : const_info_vec) {
310 const auto& data = const_info->data;
311 const auto& offs = const_info->byte_offset;
312 int64_t num_elements = std::accumulate(data.Shape().begin(), data.Shape().end(), 1,
313 std::multiplies<int64_t>());
314 code_ << " ";
315 codegen_c_base_.PrintType(data.DataType(), code_);
316 code_ << " " << const_info->name_hint << "[" << num_elements << "] __attribute__(("
317 << (ord++ ? "packed, " : "") << "aligned(" << metadata_->constant_alignment << ")));";
318 code_ << " // " << num_elements * data.DataType().bytes()
319 << " bytes, aligned offset: " << offs << "\n";
320 }
321 code_ << "} " << pool_info->pool_name << " = {\n";
322
323 // emit struct field initialization data
324 for (const auto& const_info : const_info_vec) {
325 code_ << " ." << const_info->name_hint << " = {\n";
326 codegen::NDArrayDataToC(const_info->data, 4, code_);
327 code_ << " },\n";
328 }
329 code_ << "};";
330 code_ << "// of total size " << allocated_size << " bytes\n";
331 } else {
332 LOG(FATAL) << "No constant data in constant pool found " << GetRef<ObjectRef>(pool_info);
333 }
334 }
335
336 void GenerateWorkspaceBuffer(const WorkspacePoolInfoNode* pool_info, size_t allocated_size) {
337 code_ << "__attribute__((section(\".bss.noinit.tvm\"), ";
338 code_ << "aligned(" << metadata_->workspace_alignment << ")))\n";
339 code_ << "static uint8_t " << pool_info->pool_name << "[";
340 code_ << allocated_size << "];\n";
341 }
342
343 bool IsInternalWorkspaceBuffer(const tir::Var& pool_var) {
344 if (metadata_->pool_inputs.defined()) {
345 Map<tir::Var, tir::usmp::AllocatedPoolInfo> allocated_pool_infos =
346 metadata_->pool_inputs.value();
347 if (allocated_pool_infos.find(pool_var) != allocated_pool_infos.end()) {
348 tir::usmp::AllocatedPoolInfo allocate_pool_info = allocated_pool_infos[pool_var];
349 if (allocate_pool_info->pool_info->is_internal) {
350 return true;
351 }
352 }
353 }
354 return false;
355 }
356
357 void GenerateEntrypointForUnpackedAPI(const std::string& entrypoint_name,
358 const std::string& run_func) {
359 code_ << "TVM_DLL int32_t " << run_func << "(";
360
361 {
362 std::stringstream call_args_ss;
363 if (metadata_->io_pool_allocations.empty()) {
364 for (const tir::Var& input_var : metadata_->inputs) {
365 if (input_var->type_annotation.defined()) {
366 codegen_c_base_.PrintType(input_var->type_annotation, call_args_ss);
367 } else {
368 codegen_c_base_.PrintType(input_var.dtype(), call_args_ss);
369 }
370 call_args_ss << " " << input_var->name_hint << ",";
371 }
372 for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) {
373 call_args_ss << "void* output" << i << ",";
374 }
375 }
376 for (const tir::Var& pool_var : metadata_->pools) {
377 if (pool_var->type_annotation.defined()) {
378 codegen_c_base_.PrintType(pool_var->type_annotation, call_args_ss);
379 } else {
380 codegen_c_base_.PrintType(pool_var.dtype(), call_args_ss);
381 }
382 call_args_ss << " " << pool_var->name_hint << ",";
383 }
384 std::string call_args_str = call_args_ss.str();
385 call_args_str.pop_back();
386 code_ << call_args_str;
387 }
388
389 code_ << ");\n";
390 code_ << "int32_t " << entrypoint_name;
391 code_ << "(void* args, void* type_code, int num_args, void* out_value, void* "
392 "out_type_code, void* resource_handle) {\n";
393 code_ << "return " << run_func << "(";
394
395 {
396 std::stringstream call_args_ss;
397 if (metadata_->io_pool_allocations.empty()) {
398 for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) {
399 call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data,";
400 }
401 for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) {
402 int j = metadata_->inputs.size() + i;
403 call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data,";
404 }
405 }
406 for (const tir::Var& pool_var : metadata_->pools) {
407 if (IsInternalWorkspaceBuffer(pool_var)) {
408 call_args_ss << "&" << metadata_->pool_inputs.value()[pool_var]->pool_info->pool_name
409 << ",";
410 }
411 }
412 std::string call_args_str = call_args_ss.str();
413 call_args_str.pop_back();
414 code_ << call_args_str;
415 code_ << ");\n";
416 code_ << "}\n";
417 }
418 }
419
420 std::unordered_map<int, ObjectRef> GenerateRunFuncToEntryPointArgMap() {
421 std::unordered_map<int, ObjectRef> run_func_to_entry_point_args;
422 int entrypoint_arg_count = 0;
423 int run_func_arg_count = 0;
424
425 if (metadata_->io_pool_allocations.empty()) {
426 for (unsigned int i = 0; i < metadata_->inputs.size(); i++) {
427 run_func_to_entry_point_args[run_func_arg_count] = Integer(entrypoint_arg_count);
428 entrypoint_arg_count++;
429 run_func_arg_count++;
430 }
431 for (unsigned int i = 0; i < metadata_->outputs.size(); i++) {
432 run_func_to_entry_point_args[run_func_arg_count] = Integer(entrypoint_arg_count);
433 entrypoint_arg_count++;
434 run_func_arg_count++;
435 }
436 }
437 for (const tir::Var& pool_var : metadata_->pools) {
438 if (IsInternalWorkspaceBuffer(pool_var)) {
439 tir::usmp::AllocatedPoolInfo allocated_pool_info = metadata_->pool_inputs.value()[pool_var];
440 run_func_to_entry_point_args[run_func_arg_count] =
441 allocated_pool_info->pool_info->pool_name;
442 run_func_arg_count++;
443 }
444 }
445 return run_func_to_entry_point_args;
446 }
447
448 void GenerateEntrypointForPackedAPI(const std::string& entrypoint_name,
449 const std::string& run_func) {
450 code_ << "TVM_DLL int32_t " << run_func;
451 code_ << "(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, int* "
452 "out_type_code, void* resource_handle);\n\n";
453
454 code_ << "int32_t " << entrypoint_name;
455 code_ << "(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, int* "
456 "out_type_code, void* resource_handle) {\n";
457
458 // We are creating a copy of the set of pointers
459 size_t number_of_io_tensors = metadata_->inputs.size() + metadata_->outputs.size() +
460 metadata_->pools.size() - metadata_->io_pool_allocations.size();
461 code_ << "TVMValue tensors[" << number_of_io_tensors << "];\n";
462
463 std::unordered_map<int, ObjectRef> run_func_to_entry_point_args =
464 GenerateRunFuncToEntryPointArgMap();
465 for (unsigned int i = 0; i < number_of_io_tensors; i++) {
466 if (run_func_to_entry_point_args.find(i) != run_func_to_entry_point_args.end()) {
467 if (run_func_to_entry_point_args[i]->IsInstance<StringObj>()) {
468 String pool_name = Downcast<String>(run_func_to_entry_point_args[i]);
469 String pool_name_tvmv = GenerateDLTensorStructWrapper(pool_name);
470 code_ << "tensors[" << i << "] = " << pool_name_tvmv << ";\n";
471 } else {
472 code_ << "tensors[" << i << "] = ((TVMValue*)args)[" << run_func_to_entry_point_args[i]
473 << "];\n";
474 }
475 }
476 }
477
478 code_ << "return " << run_func;
479 code_ << "((void*)tensors, type_code, num_args, out_value, out_type_code, resource_handle);\n";
480 code_ << "}\n";
481 }
482
483 static int isNotAlnum(char c) { return !std::isalnum(c); }
484
485 void GenerateCInterfaceEntrypoint(const std::string& entrypoint_name, const std::string& run_func,
486 const std::string& mod_name) {
487 code_ << "#include <" << mod_name << ".h>\n";
488 if (!metadata_->io_pool_allocations.empty()) {
489 const std::string input_struct_type =
490 runtime::get_name_mangled(metadata_->mod_name, "inputs");
491 Array<String> input_tensor_names;
492 for (const tir::Var& input_var : metadata_->inputs) {
493 input_tensor_names.push_back(input_var->name_hint);
494 }
495 GenerateIOWorkspaceMapFunction(input_struct_type, "map_inputs", input_tensor_names);
496 const std::string output_struct_type =
497 runtime::get_name_mangled(metadata_->mod_name, "outputs");
498 GenerateIOWorkspaceMapFunction(output_struct_type, "map_outputs", metadata_->outputs);
499 }
500 code_ << "TVM_DLL int32_t " << run_func << "(";
501 {
502 std::stringstream call_args_ss;
503 if (metadata_->io_pool_allocations.empty()) {
504 for (const tir::Var& input_var : metadata_->inputs) {
505 if (input_var->type_annotation.defined()) {
506 codegen_c_base_.PrintType(input_var->type_annotation, call_args_ss);
507 } else {
508 codegen_c_base_.PrintType(input_var.dtype(), call_args_ss);
509 }
510 call_args_ss << " " << tvm::runtime::SanitizeName(input_var->name_hint) << ",";
511 }
512 for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) {
513 call_args_ss << "void* output" << i << ",";
514 }
515 }
516 for (const tir::Var& pool_var : metadata_->pools) {
517 if (pool_var->type_annotation.defined()) {
518 codegen_c_base_.PrintType(pool_var->type_annotation, call_args_ss);
519 } else {
520 codegen_c_base_.PrintType(pool_var.dtype(), call_args_ss);
521 }
522 call_args_ss << " " << pool_var->name_hint << ",";
523 }
524 for (const String& device : metadata_->devices) {
525 call_args_ss << "void* " << device << ",";
526 }
527 std::string call_args_str = call_args_ss.str();
528 call_args_str.pop_back();
529 code_ << call_args_str;
530 }
531
532 code_ << ");\n";
533 code_ << "int32_t " << entrypoint_name << "(";
534 {
535 std::stringstream call_args_ss;
536 if (metadata_->io_pool_allocations.empty()) {
537 call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "inputs") << "* inputs,";
538 call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "outputs") << "* outputs,";
539 }
540 if (!metadata_->pools.empty()) {
541 bool is_external_pools_present = false;
542 for (tir::Var pool_var : metadata_->pools) {
543 if (!IsInternalWorkspaceBuffer(pool_var)) {
544 is_external_pools_present = true;
545 break;
546 }
547 }
548 if (is_external_pools_present) {
549 call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "workspace_pools")
550 << "* workspace_pools,";
551 }
552 }
553 if (!metadata_->devices.empty()) {
554 call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "devices") << "* devices,";
555 }
556 std::string call_args_str = call_args_ss.str();
557 call_args_str.pop_back();
558 code_ << call_args_str;
559 }
560
561 code_ << ") {"
562 << "return " << run_func << "(";
563
564 {
565 std::stringstream call_args_ss;
566 if (metadata_->io_pool_allocations.empty()) {
567 for (const auto& input : metadata_->inputs) {
568 call_args_ss << "inputs->" << tvm::runtime::SanitizeName(input->name_hint) << ",";
569 }
570 for (const auto& output : metadata_->outputs) {
571 call_args_ss << "outputs->" << tvm::runtime::SanitizeName(output);
572 call_args_ss << ",";
573 }
574 }
575
576 for (const tir::Var& pool_var : metadata_->pools) {
577 String pool_name = metadata_->pool_inputs.value()[pool_var]->pool_info->pool_name;
578 if (IsInternalWorkspaceBuffer(pool_var)) {
579 call_args_ss << "&" << pool_name << ",";
580 } else {
581 call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name) << ",";
582 }
583 }
584 for (const String& device : metadata_->devices) {
585 call_args_ss << "devices->" << device << ",";
586 }
587 std::string call_args_str = call_args_ss.str();
588 call_args_str.pop_back();
589 code_ << call_args_str;
590 }
591 code_ << ");\n";
592 code_ << "}\n";
593 }
594
595 void GenerateAOTDescriptor() {
596 const std::string run_func_suffix = ::tvm::runtime::symbol::tvm_module_main;
597 const std::string tvm_entrypoint_suffix = ::tvm::runtime::symbol::tvm_entrypoint_suffix;
598 const std::string run_func_mangled =
599 runtime::get_name_mangled(metadata_->mod_name, run_func_suffix);
600 const std::string entrypoint_mangled =
601 runtime::get_name_mangled(metadata_->mod_name, tvm_entrypoint_suffix);
602 const std::string network_mangled = runtime::get_name_mangled(metadata_->mod_name, "network");
603
604 code_ << "#include \"tvm/runtime/c_runtime_api.h\"\n";
605 code_ << "#ifdef __cplusplus\n";
606 code_ << "extern \"C\" {\n";
607 code_ << "#endif\n";
608
609 GenerateInternalBuffers();
610
611 if (metadata_->unpacked_api) {
612 if (metadata_->interface_api == "c") {
613 GenerateCInterfaceEntrypoint(entrypoint_mangled, run_func_mangled, metadata_->mod_name);
614 } else {
615 GenerateEntrypointForUnpackedAPI(entrypoint_mangled, run_func_mangled);
616 }
617 } else {
618 ICHECK_EQ(metadata_->interface_api, "packed")
619 << "Packed interface required for packed operators";
620 GenerateEntrypointForPackedAPI(entrypoint_mangled, run_func_mangled);
621 }
622
623 code_ << "#ifdef __cplusplus\n";
624 code_ << "}\n";
625 code_ << "#endif\n";
626 }
627
628 void CreateSource() {
629 if (runtime_->GetAttr<Bool>("system-lib").value_or(Bool(false)) && !func_names_.empty()) {
630 CreateFuncRegistry();
631 GenerateCrtSystemLib();
632 }
633 if (metadata_.defined() && metadata_->executor == runtime::kTvmExecutorAot) {
634 GenerateAOTDescriptor();
635 }
636 code_ << ";";
637 }
638};
639
640class MetadataSerializer : public AttrVisitor {
641 public:
642 static constexpr const char* kGlobalSymbol = "kTvmgenMetadata";
643 using MetadataKind = ::tvm::runtime::metadata::MetadataKind;
644
645 MetadataSerializer() : is_first_item_{true} {}
646
647 void WriteComma() {
648 if (is_first_item_) {
649 is_first_item_ = false;
650 } else {
651 code_ << ", " << std::endl;
652 }
653 }
654
655 void WriteKey(const char* key) {
656 if (key != nullptr) {
657 code_ << " /* " << key << "*/";
658 }
659 }
660
661 void Visit(const char* key, double* value) final {
662 WriteComma();
663 code_.setf(std::ios::hex | std::ios::showbase | std::ios::fixed | std::ios::scientific,
664 std::ios::basefield | std::ios::showbase | std::ios::floatfield);
665 code_ << *value;
666 WriteKey(key);
667 }
668
669 void Visit(const char* key, int64_t* value) final {
670 WriteComma();
671 code_ << *value << "L";
672 WriteKey(key);
673 }
674
675 void Visit(const char* key, uint64_t* value) final {
676 WriteComma();
677 code_ << *value << "UL";
678 WriteKey(key);
679 }
680 void Visit(const char* key, int* value) final {
681 WriteComma();
682 code_ << *value;
683 WriteKey(key);
684 }
685 void Visit(const char* key, bool* value) final {
686 WriteComma();
687 code_ << *value;
688 WriteKey(key);
689 }
690 void Visit(const char* key, std::string* value) final {
691 WriteComma();
692 code_ << "\"" << *value << "\"";
693 WriteKey(key);
694 }
695 void Visit(const char* key, void** value) final {
696 WriteComma();
697 code_ << *value;
698 WriteKey(key);
699 }
700 void Visit(const char* key, DataType* value) final {
701 WriteComma();
702 code_ << "{" << value->code() << ", " << value->bits() << ", " << value->lanes() << "}";
703 WriteKey(key);
704 }
705
706 // Serialiding NDArray as tuple of len, data
707 void Visit(const char* key, runtime::NDArray* value) final {
708 WriteComma();
709 std::string bytes;
710 dmlc::MemoryStringStream stream(&bytes);
711 value->Save(&stream);
712 // Serializing length of the data of NDArray
713 code_ << stream.Tell();
714 WriteComma();
715 // Serializing NDArray as bytestream
716 code_ << "\"";
717 std::stringstream ss;
718 char buf[6] = {0};
719 for (uint8_t c : bytes) {
720 snprintf(buf, sizeof(buf), "\\x%02x", c);
721 ss << buf;
722 }
723 std::string as_bytes(ss.str());
724 code_ << as_bytes;
725 code_ << "\"\n";
726 }
727
728 void VisitArray(runtime::metadata::MetadataArray array) {
729 auto old_is_first_item = is_first_item_;
730 is_first_item_ = true;
731 for (unsigned int i = 0; i < array->array.size(); ++i) {
732 ObjectRef o = array->array[i];
733
734 switch (array->kind) {
735 case MetadataKind::kUint64: {
736 int64_t i = Downcast<Integer>(o).IntValue();
737 CHECK_GT(i, 0)
738 << "Metadata is of type uint64_t, but array type contains a negative number";
739 uint64_t ui = static_cast<uint64_t>(i);
740 Visit(nullptr, &ui);
741 continue;
742 }
743 case MetadataKind::kInt64: {
744 int64_t i = Downcast<Integer>(o).IntValue();
745 Visit(nullptr, &i);
746 continue;
747 }
748 case MetadataKind::kBool: {
749 bool b = Downcast<Bool>(o);
750 Visit(nullptr, &b);
751 break;
752 }
753 case MetadataKind::kString: {
754 std::string s = Downcast<String>(o);
755 Visit(nullptr, &s);
756 break;
757 }
758 case MetadataKind::kHandle:
759 CHECK(false) << "Don't know how to serialize handle";
760 break;
761
762 case MetadataKind::kMetadata: {
763 runtime::metadata::MetadataBase metadata = Downcast<runtime::metadata::MetadataBase>(o);
764 std::stringstream i_str;
765 i_str << i;
766 address_.push_back(i_str.str());
767 Visit(nullptr, &metadata);
768 address_.pop_back();
769 break;
770 }
771 default:
772 CHECK(false) << "Unknown MetadataKind for array: " << array->kind;
773 break;
774 }
775 is_first_item_ = false;
776 }
777 is_first_item_ = old_is_first_item;
778 }
779
780 void Visit(const char* key, ObjectRef* value) final {
781 const runtime::metadata::MetadataArrayNode* arr =
782 value->as<runtime::metadata::MetadataArrayNode>();
783 if (arr != nullptr) {
784 WriteComma();
785 if (key != nullptr) {
786 address_.push_back(key);
787 }
788 code_ << metadata::AddressFromParts(address_);
789 if (key != nullptr) {
790 address_.pop_back();
791 }
792 return;
793 }
794
795 runtime::metadata::MetadataBase metadata = Downcast<runtime::metadata::MetadataBase>(*value);
796 if (key != nullptr) { // NOTE: outermost call passes nullptr key
797 address_.push_back(key);
798 }
799 WriteComma();
800 code_ << "{\n";
801 is_first_item_ = true;
802 ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
803 code_ << "}\n";
804 if (key != nullptr) { // NOTE: outermost call passes nullptr key
805 address_.pop_back();
806 }
807 }
808
809 private:
810 void EmitCType(const runtime::metadata::MetadataArrayNode* arr, std::ostream& os) {
811 switch (arr->kind) {
812 case MetadataKind::kUint64:
813 os << "uint64_t";
814 break;
815 case MetadataKind::kInt64:
816 os << "int64_t";
817 break;
818 case MetadataKind::kBool:
819 os << "bool";
820 break;
821 case MetadataKind::kString:
822 os << "const char*";
823 break;
824 case MetadataKind::kHandle:
825 os << "void*";
826 break;
827 case MetadataKind::kMetadata:
828 os << "struct " << arr->get_element_c_struct_name();
829 break;
830 default:
831 CHECK(false) << "Unknown kind in MetadataArray: " << arr->kind
832 << " (struct_name=" << arr->get_c_struct_name() << ")";
833 break;
834 }
835 }
836
837 public:
838 void CodegenMetadata(::tvm::runtime::metadata::Metadata metadata) {
839 decl_ << "#include <inttypes.h>" << std::endl
840 << "#include <tvm/runtime/metadata_types.h>" << std::endl
841 << "#include <tvm/runtime/c_runtime_api.h>" << std::endl;
842 std::vector<metadata::DiscoverArraysVisitor::DiscoveredArray> queue;
843 metadata::DiscoverArraysVisitor array_discover{&queue};
844 array_discover.Visit(metadata::kMetadataGlobalSymbol, &metadata);
845
846 for (auto item : queue) {
847 auto struct_address = std::get<0>(item);
848 address_.push_back(struct_address);
849
850 auto arr = std::get<1>(item);
851
852 // Prepend const with everything except C-string, which needs appending.
853 code_ << "static ";
854 if (arr->kind != MetadataKind::kString) {
855 code_ << "const ";
856 }
857 EmitCType(arr.operator->(), code_);
858 if (arr->kind == MetadataKind::kString) {
859 code_ << " const";
860 }
861 code_ << " " << struct_address << "[" << arr->array.size() << "] = {" << std::endl;
862 is_first_item_ = true;
863
864 VisitArray(arr);
865 address_.pop_back();
866 code_ << "};" << std::endl;
867 }
868
869 // Finally, emit overall struct.
870 address_.push_back(metadata::kMetadataGlobalSymbol);
871 code_ << "static const struct TVMMetadata " << metadata::AddressFromParts(address_) << "[1] = {"
872 << std::endl;
873 Visit(nullptr, &metadata);
874 code_ << "};" << std::endl;
875 address_.pop_back();
876 }
877
878 std::string GetOutput() { return decl_.str() + code_.str(); }
879
880 private:
881 std::vector<std::string> address_;
882 std::stringstream decl_;
883 std::stringstream code_;
884 bool is_first_item_;
885 std::unordered_set<std::string> generated_struct_decls_;
886 std::vector<bool> is_defining_struct_;
887};
888
889namespace {
890runtime::Module CreateAotMetadataModule(runtime::metadata::Metadata aot_metadata,
891 bool is_c_runtime) {
892 MetadataSerializer serializer;
893 serializer.CodegenMetadata(aot_metadata);
894 std::stringstream lookup_func;
895 std::string get_c_metadata_func_name;
896
897 // NOTE: mangling is not needed in the c++ runtime because the function
898 // name is looked-up via LibraryModule.
899 // TODO(alanmacd): unify these two approaches
900
901 if (is_c_runtime == true) {
902 get_c_metadata_func_name = runtime::get_name_mangled(
903 aot_metadata->mod_name(), ::tvm::runtime::symbol::tvm_get_c_metadata);
904 } else {
905 get_c_metadata_func_name = ::tvm::runtime::symbol::tvm_get_c_metadata;
906 }
907
908 lookup_func << "#ifdef __cplusplus\n"
909 << "extern \"C\"\n"
910 << "#endif\n";
911
912 lookup_func << "TVM_DLL int32_t " << get_c_metadata_func_name
913 << "(TVMValue* arg_values, int* arg_tcodes, int "
914 "num_args, TVMValue* ret_values, int* ret_tcodes, void* resource_handle) {"
915 << std::endl;
916 lookup_func << " ret_values[0].v_handle = (void*) &" << MetadataSerializer::kGlobalSymbol
917 << ";" << std::endl;
918 lookup_func << " ret_tcodes[0] = kTVMOpaqueHandle;" << std::endl;
919 lookup_func << " return 0;" << std::endl;
920 lookup_func << "};" << std::endl;
921 std::vector<String> func_names{get_c_metadata_func_name};
922 return CSourceModuleCreate(serializer.GetOutput() + lookup_func.str(), "c", func_names,
923 Array<String>());
924}
925} // namespace
926
927runtime::Module CreateCSourceCrtMetadataModule(const Array<runtime::Module>& modules, Target target,
928 relay::Runtime runtime,
929 relay::backend::ExecutorCodegenMetadata metadata,
930 runtime::metadata::Metadata aot_metadata) {
931 Array<runtime::Module> final_modules(modules);
932 Array<String> func_names;
933
934 if (metadata.defined()) {
935 if (metadata->executor == "aot") {
936 if (aot_metadata.defined()) {
937 final_modules.push_back(CreateAotMetadataModule(aot_metadata, true));
938 }
939
940 // add the run function (typically "tvmgen_default_run") to function registry
941 // when using AOT executor
942 std::string run_func = runtime::get_name_mangled(metadata->mod_name, "run");
943 func_names.push_back(run_func);
944 }
945 }
946
947 for (runtime::Module mod : final_modules) {
948 auto pf_funcs = mod.GetFunction("get_func_names");
949 if (pf_funcs != nullptr) {
950 Array<String> func_names_ = pf_funcs();
951 for (const auto& fname : func_names_) {
952 func_names.push_back(fname);
953 }
954 }
955 }
956
957 auto n = make_object<CSourceCrtMetadataModuleNode>(func_names, "c", target, runtime, metadata);
958 auto csrc_metadata_module = runtime::Module(n);
959 for (const auto& mod : final_modules) {
960 csrc_metadata_module.Import(mod);
961 }
962
963 return std::move(csrc_metadata_module);
964}
965
966runtime::Module CreateCSourceCppMetadataModule(runtime::metadata::Metadata metadata) {
967 MetadataSerializer serializer;
968 serializer.CodegenMetadata(metadata);
969 std::stringstream lookup_func;
970 lookup_func << "#ifdef __cplusplus\n"
971 << "extern \"C\"\n"
972 << "#endif\n";
973
974 lookup_func << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_get_c_metadata
975 << "(TVMValue* arg_values, int* arg_tcodes, int "
976 "num_args, TVMValue* ret_values, int* ret_tcodes, void* resource_handle) {"
977 << std::endl;
978 lookup_func << " ret_values[0].v_handle = (void*) &" << metadata::kMetadataGlobalSymbol << ";"
979 << std::endl;
980 lookup_func << " ret_tcodes[0] = kTVMOpaqueHandle;" << std::endl;
981 lookup_func << " return 0;" << std::endl;
982 lookup_func << "};" << std::endl;
983
984 auto mod = MetadataModuleCreate(metadata);
985 mod->Import(CreateAotMetadataModule(metadata, false));
986 return mod;
987}
988
989// supports limited save without cross compile
990class DeviceSourceModuleNode final : public runtime::ModuleNode {
991 public:
992 DeviceSourceModuleNode(std::string data, std::string fmt,
993 std::unordered_map<std::string, FunctionInfo> fmap, std::string type_key,
994 std::function<std::string(const std::string&)> fget_source)
995 : data_(data), fmt_(fmt), fmap_(fmap), type_key_(type_key), fget_source_(fget_source) {}
996
997 PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
998 LOG(FATAL) << "Source module cannot execute, to get executable module"
999 << " build TVM with \'" << fmt_ << "\' runtime support";
1000 return PackedFunc();
1001 }
1002
1003 std::string GetSource(const std::string& format) final {
1004 if (fget_source_ != nullptr) {
1005 return fget_source_(format);
1006 } else {
1007 return data_;
1008 }
1009 }
1010
1011 const char* type_key() const final { return type_key_.c_str(); }
1012
1013 void SaveToFile(const std::string& file_name, const std::string& format) final {
1014 std::string fmt = GetFileFormat(file_name, format);
1015 ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
1016 std::string meta_file = GetMetaFilePath(file_name);
1017 SaveMetaDataToFile(meta_file, fmap_);
1018 SaveBinaryToFile(file_name, data_);
1019 }
1020
1021 void SaveToBinary(dmlc::Stream* stream) final {
1022 stream->Write(fmt_);
1023 stream->Write(fmap_);
1024 stream->Write(data_);
1025 }
1026
1027 private:
1028 std::string data_;
1029 std::string fmt_;
1030 std::unordered_map<std::string, FunctionInfo> fmap_;
1031 std::string type_key_;
1032 std::function<std::string(const std::string&)> fget_source_;
1033};
1034
1035runtime::Module DeviceSourceModuleCreate(
1036 std::string data, std::string fmt, std::unordered_map<std::string, FunctionInfo> fmap,
1037 std::string type_key, std::function<std::string(const std::string&)> fget_source) {
1038 auto n = make_object<DeviceSourceModuleNode>(data, fmt, fmap, type_key, fget_source);
1039 return runtime::Module(n);
1040}
1041
1042TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate);
1043
1044TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate")
1045 .set_body_typed([](String code, String fmt, Array<String> func_names,
1046 Array<String> const_vars) {
1047 return CSourceModuleCreate(code, fmt, func_names, const_vars);
1048 });
1049
1050TVM_REGISTER_GLOBAL("runtime.CreateCSourceCrtMetadataModule")
1051 .set_body_typed([](const Array<runtime::Module>& modules, Target target,
1052 relay::Runtime runtime) {
1053 // Note that we don't need metadata when we compile a single operator
1054 return CreateCSourceCrtMetadataModule(modules, target, runtime,
1055 relay::backend::ExecutorCodegenMetadata(),
1056 runtime::metadata::Metadata());
1057 });
1058
1059} // namespace codegen
1060} // namespace tvm
1061