1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file codegen_c_host.cc
 * \brief Host-side C code generator (CodeGenCHost) and the "target.build.c" build entry point.
 */
23#include "codegen_c_host.h"
24
25#include <tvm/relay/executor.h>
26#include <tvm/relay/runtime.h>
27#include <tvm/runtime/crt/error_codes.h>
28#include <tvm/runtime/module.h>
29#include <tvm/target/codegen.h>
30
31#include <algorithm>
32#include <string>
33#include <unordered_set>
34#include <utility>
35#include <vector>
36
37#include "../../support/str_escape.h"
38#include "../build_common.h"
39#include "../func_registry_generator.h"
40#include "codegen_params.h"
41
42namespace tvm {
43namespace codegen {
44
45CodeGenCHost::CodeGenCHost() { module_name_ = name_supply_->FreshName("__tvm_module_ctx"); }
46
47void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl,
48 std::string target_str, const std::unordered_set<std::string>& devices) {
49 emit_asserts_ = emit_asserts;
50 emit_fwd_func_decl_ = emit_fwd_func_decl;
51 declared_globals_.clear();
52 decl_stream << "// tvm target: " << target_str << "\n";
53 decl_stream << "#define TVM_EXPORTS\n";
54 decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n";
55 decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n";
56 decl_stream << "#include <math.h>\n";
57 decl_stream << "#include <stdbool.h>\n";
58 if (devices.find("ethos-u") != devices.end()) {
59 decl_stream << "#include <tvm_ethosu_runtime.h>\n";
60 }
61 if (devices.find("cmsis-nn") != devices.end()) {
62 decl_stream << "#include <stdio.h>\n";
63 decl_stream << "#include <stdlib.h>\n";
64 decl_stream << "#include <dlpack/dlpack.h>\n";
65 decl_stream << "#include <arm_nnfunctions.h>\n";
66 decl_stream << "#include <arm_nn_types.h>\n";
67 decl_stream << "#include <arm_nn_math_types.h>\n";
68 }
69 CodeGenC::Init(output_ssa);
70}
71
72void CodeGenCHost::InitGlobalContext() {
73 decl_stream << "void* " << tvm::runtime::symbol::tvm_module_ctx << " = NULL;\n";
74}
75
76void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; }
77
78void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) {
79 auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
80 ICHECK(global_symbol.defined())
81 << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute";
82 function_names_.push_back(global_symbol.value());
83
84 emit_fwd_func_decl_ = emit_fwd_func_decl;
85 CodeGenC::AddFunction(f);
86 if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
87 function_names_.push_back(runtime::symbol::tvm_module_main);
88 stream << "// CodegenC: NOTE: Auto-generated entry function\n";
89 PrintFuncPrefix(stream);
90 stream << " " << tvm::runtime::symbol::tvm_module_main
91 << "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, "
92 << "int* out_ret_tcode, void* resource_handle) {\n";
93 stream << " return " << global_symbol.value()
94 << "(args, arg_type_ids, num_args, out_ret_value, out_ret_tcode, resource_handle);\n";
95 stream << "}\n";
96 }
97}
98
99void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol,
100 const Array<PrimExpr>& args) {
101 if (!emit_fwd_func_decl_) {
102 return;
103 }
104 for (auto& func_already_defined : GetFunctionNames()) {
105 if (global_symbol == func_already_defined) {
106 return;
107 }
108 }
109 this->PrintFuncPrefix(fwd_decl_stream);
110 fwd_decl_stream << " " << global_symbol << "(";
111 for (size_t i = 1; i < args.size(); ++i) {
112 CodeGenSourceBase::PrintType(GetType(args[i]), fwd_decl_stream);
113 fwd_decl_stream << " ", this->PrintExpr(args[i], fwd_decl_stream);
114 if (i < args.size() - 1) {
115 fwd_decl_stream << ", ";
116 }
117 }
118 fwd_decl_stream << ");\n";
119}
120
121void CodeGenCHost::PrintFuncPrefix(std::ostream& os) { // NOLINT(*)
122 os << "#ifdef __cplusplus\n"
123 << "extern \"C\"\n"
124 << "#endif\n"
125 << "TVM_DLL int32_t";
126}
127
128void CodeGenCHost::PrintFinalReturn() { // NOLINT(*)
129 this->PrintIndent();
130 stream << "return 0;\n";
131}
132
133std::string CodeGenCHost::Finish() { // NOLINT(*)
134 std::string ret = decl_stream.str();
135 if (emit_fwd_func_decl_) {
136 ret += fwd_decl_stream.str();
137 }
138 ret += stream.str();
139 return ret;
140}
141
142void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
143 int lanes = t.lanes();
144 if (t.is_handle()) {
145 ICHECK_EQ(lanes, 1) << "does not support vector types";
146 os << "void*";
147 return;
148 }
149 if (t.is_void()) {
150 os << "void";
151 return;
152 }
153 if (t == DataType::Bool()) {
154 os << "bool";
155 return;
156 }
157 bool fail = false;
158 if (t.is_float()) {
159 switch (t.bits()) {
160 case 16:
161 os << "half";
162 break;
163 case 32:
164 os << "float";
165 break;
166 case 64:
167 os << "double";
168 break;
169 default:
170 fail = true;
171 break;
172 }
173 if (!fail && lanes == 1) return;
174 if (!fail && (lanes >= 2 && lanes <= 16)) {
175 os << lanes;
176 return;
177 }
178 } else if (t.is_uint() || t.is_int()) {
179 if (t.is_uint()) {
180 os << 'u';
181 }
182 switch (t.bits()) {
183 case 8:
184 os << "int8_t";
185 break;
186 case 16:
187 os << "int16_t";
188 break;
189 case 32:
190 os << "int32_t";
191 break;
192 case 64:
193 os << "int64_t";
194 break;
195 case 1:
196 os << "int32_t";
197 break;
198 default:
199 fail = true;
200 break;
201 }
202 if (!fail && lanes == 1) return;
203 if (!fail && (lanes >= 2 && lanes <= 16)) {
204 os << lanes;
205 return;
206 }
207 }
208 LOG(FATAL) << "Cannot convert type " << t << " to C type";
209}
210
211void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
212 std::string v = PrintExpr(op->value);
213 os << "((";
214 PrintType(op->dtype, os);
215 os << ")(";
216 for (int i = 0; i < op->lanes; ++i) {
217 if (i != 0) os << ", ";
218 os << v;
219 }
220 os << "))";
221}
222
223void CodeGenCHost::PrintGetFuncFromBackend(const std::string& func_name,
224 const std::string& packed_func_name) {
225 this->PrintIndent();
226 this->stream << "if (" << packed_func_name << " == NULL) {\n";
227 int packed_func_if_scope = this->BeginScope();
228 this->PrintIndent();
229 this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_ << ", \"" << func_name << "\""
230 << ", &" << packed_func_name << ") != 0) {\n";
231 int get_func_env_scope = this->BeginScope();
232 this->PrintIndent();
233 this->stream << "return -1;\n";
234 this->EndScope(get_func_env_scope);
235 this->PrintIndent();
236 this->stream << "}\n";
237 this->EndScope(packed_func_if_scope);
238 this->PrintIndent();
239 this->stream << "}\n";
240}
241
242void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_args) {
243 this->PrintIndent();
244 std::string ret_val = name_supply_->FreshName("ret_val");
245 std::string ret_type_code = name_supply_->FreshName("ret_type_code");
246 this->stream << "TVMValue " << ret_val << ";\n";
247 this->PrintIndent();
248 this->stream << "int " << ret_type_code << ";\n";
249 this->PrintIndent();
250 this->stream << "if (TVMFuncCall(" << packed_func_name << ", "
251 << "(TVMValue*) stack_value"
252 << ", "
253 << "(int*) stack_tcode"
254 << ", " << num_args << ", "
255 << "&" << ret_val << ", "
256 << "&" << ret_type_code << ") != 0) {\n";
257 int func_call_scope = this->BeginScope();
258 this->PrintIndent();
259 this->stream << "return -1;\n";
260 this->EndScope(func_call_scope);
261 this->PrintIndent();
262 this->stream << "}\n";
263}
264
265void CodeGenCHost::PrintFuncCallC(const std::string& packed_func_name, int num_args,
266 const std::string& resource_handle_name) {
267 this->PrintIndent();
268 std::string ret_val = name_supply_->FreshName("ret_val");
269 std::string ret_type_code = name_supply_->FreshName("ret_type_code");
270 this->stream << "TVMValue " << ret_val << ";\n";
271 this->PrintIndent();
272 this->stream << "int " << ret_type_code << ";\n";
273 this->PrintIndent();
274
275 this->stream << "if (" << packed_func_name << "( "
276 << "(TVMValue*) stack_value "
277 << ", "
278 << "(int*) stack_tcode"
279 << ", " << num_args << ", "
280 << "&" << ret_val << ", "
281 << "&" << ret_type_code << ", " << resource_handle_name << ") != 0){\n";
282
283 int func_call_scope = this->BeginScope();
284 this->PrintIndent();
285 this->stream << "return -1;\n";
286 this->EndScope(func_call_scope);
287 this->PrintIndent();
288 this->stream << "}\n";
289}
290
291std::string CodeGenCHost::GetPackedName(const CallNode* op) {
292 const StringImmNode* s = op->args[0].as<StringImmNode>();
293 ICHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name";
294 std::string func_name = s->value;
295 std::string packed_func_name = func_name + "_packed";
296 std::string unique_name;
297 auto it = declared_globals_.find(packed_func_name);
298 if (it != declared_globals_.end()) {
299 unique_name = it->second;
300 } else {
301 unique_name = name_supply_->FreshName(packed_func_name);
302 declared_globals_[packed_func_name] = unique_name;
303 decl_stream << "static void* " << unique_name << " = NULL;\n";
304 }
305 return unique_name;
306}
307
308CodeGenCHost::FunctionInfo CodeGenCHost::GetFunctionInfo(const CallNode* op,
309 bool has_resource_handle) {
310 const StringImmNode* s = op->args[0].as<StringImmNode>();
311 ICHECK(s != nullptr) << "tvm_call_[c]packed_lowered expects first argument as function name";
312 int64_t begin = op->args[3].as<IntImmNode>()->value;
313 int64_t end = op->args[4].as<IntImmNode>()->value;
314 int64_t num_args = end - begin;
315 ICHECK_GE(num_args, 0);
316 std::string func_name = s->value;
317
318 if (has_resource_handle) {
319 const StringImmNode* resource_handle_var = op->args[5].as<StringImmNode>();
320 if (resource_handle_var != nullptr) {
321 std::string resource_handle_name = resource_handle_var->value;
322 return {func_name, num_args - 1, resource_handle_name};
323 } else {
324 // The final arg should be "(void*) NULL" to indicate the empty resource_handle.
325 num_args--;
326
327 const CallNode* reinterpret_call = op->args[5].as<CallNode>();
328 ICHECK_NE(reinterpret_call, (void*)nullptr)
329 << "At CallNode to " << s
330 << "arg 5: Expect either StringImm naming the resource_handle var from interface API or "
331 << "reinterpret(0); got: " << op->args[5];
332 ICHECK_EQ(reinterpret_call->op, builtin::reinterpret())
333 << "At CallNode to " << s
334 << "arg 5: Expect either StringImm naming the resource_handle var from interface API or "
335 << "reinterpret(0); got: " << op->args[5];
336 ICHECK(is_zero(reinterpret_call->args[0])) << "At CallNode to " << s
337 << " arg 5: Expect either StringImm naming the "
338 "resource_handle var from interface API, or "
339 << "zero; got " << op->args[5];
340 }
341 }
342 return {func_name, num_args, "NULL"};
343}
344
345void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
346 if (op->op.same_as(builtin::tvm_stack_alloca())) {
347 std::string stack_name = name_supply_->FreshName("stack");
348 const std::string& type = op->args[0].as<StringImmNode>()->value;
349 const IntImmNode* num = op->args[1].as<IntImmNode>();
350 ICHECK(num != nullptr);
351 static_assert(alignof(TVMValue) % alignof(DLTensor) == 0, "invariant");
352 size_t unit = sizeof(TVMValue);
353 size_t size = 0;
354 if (type == "shape") {
355 size = (num->value * sizeof(tvm_index_t) + unit - 1) / unit;
356 } else if (type == "arg_value") {
357 size = (num->value * sizeof(TVMValue) + unit - 1) / unit;
358 } else if (type == "arg_tcode") {
359 size = (num->value * sizeof(int) + unit - 1) / unit;
360 } else if (type == "array") {
361 size = (num->value * sizeof(DLTensor) + unit - 1) / unit;
362 } else {
363 LOG(FATAL) << "Unknown stack alloca type " << type;
364 }
365 this->PrintIndent();
366 this->stream << "TVMValue " << stack_name << "[" << size << "];\n";
367 os << stack_name;
368 } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
369 auto function_info = GetFunctionInfo(op, false /* has_resource_handle */);
370 std::string func_name_packed = GetPackedName(op);
371 this->PrintGetFuncFromBackend(function_info.func_name, func_name_packed);
372 this->PrintFuncCall(func_name_packed, function_info.num_args);
373 } else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) {
374 auto function_info = GetFunctionInfo(op, true /* has_resource_handle */);
375 this->PrintFuncCallC(function_info.func_name, function_info.num_args,
376 function_info.resource_handle_name);
377 } else if (op->op.same_as(builtin::tvm_throw_last_error())) {
378 this->PrintIndent();
379 this->stream << "return -1;\n";
380 } else {
381 CodeGenC::VisitExpr_(op, os);
382 }
383}
384
385void CodeGenCHost::VisitStmt_(const AssertStmtNode* op) { // NOLINT(*)
386 if (emit_asserts_) {
387 std::string cond = PrintExpr(op->condition);
388 PrintIndent();
389 stream << "if (!(" << cond << ")) {\n";
390 int assert_if_scope = this->BeginScope();
391 PrintIndent();
392 stream << "TVMAPISetLastError(\"" << op->message.as<StringImmNode>()->value << "\");\n";
393 PrintIndent();
394 stream << "return -1;\n";
395 this->EndScope(assert_if_scope);
396 PrintIndent();
397 stream << "}\n";
398 }
399 this->PrintStmt(op->body);
400}
401
402void CodeGenCHost::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*)
403 PrintTernaryCondExpr(op, "<", os);
404}
405
406void CodeGenCHost::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*)
407 PrintTernaryCondExpr(op, ">", os);
408}
409
410template <typename T>
411inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, const char* compare,
412 std::ostream& os) { // NOLINT(*)
413 std::ostringstream temp_a;
414 VisitExpr(op->a, temp_a);
415 std::string a_id = SSAGetID(temp_a.str(), op->a.dtype());
416 std::ostringstream temp_b;
417 VisitExpr(op->b, temp_b);
418 std::string b_id = SSAGetID(temp_b.str(), op->b.dtype());
419
420 os << "((" << a_id << ") " << compare << " (" << b_id << ") "
421 << "? (" << a_id << ") : (" << b_id << "))";
422}
423
424runtime::Module BuildCHost(IRModule mod, Target target) {
425 using tvm::runtime::Registry;
426 bool output_ssa = false;
427 bool emit_asserts = false;
428 bool emit_fwd_func_decl = true;
429
430 std::unordered_set<std::string> devices;
431 if (mod->GetAttr<Map<GlobalVar, String>>("device_contexts") != nullptr) {
432 Map<GlobalVar, String> device_contexts =
433 mod->GetAttr<Map<GlobalVar, String>>("device_contexts").value();
434 for (auto const& context : device_contexts) {
435 devices.insert(context.second.data());
436 }
437 }
438
439 CodeGenCHost cg;
440 cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
441 cg.SetConstantsByteAlignment(target->GetAttr<Integer>("constants-byte-alignment").value_or(16));
442 PrimFunc aot_executor_fn;
443
444 std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
445 for (auto kv : mod->functions) {
446 // Make sure that the executor function is the last one to be code generated so that all the
447 // symbols are available to __tvm_main__
448 auto fun_name = std::string(kv.first->name_hint);
449 bool is_aot_executor_fn = kv.second->GetAttr<Bool>("runner_function", Bool(false)).value();
450
451 if (is_aot_executor_fn) {
452 aot_executor_fn = Downcast<PrimFunc>(kv.second);
453 continue;
454 }
455 funcs.push_back(kv);
456 }
457
458 // Sort functions
459 std::sort(funcs.begin(), funcs.end(),
460 [](std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_a,
461 std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_b) {
462 std::string name_hint_a = kv_a.first->name_hint;
463 std::string name_hint_b = kv_b.first->name_hint;
464 return name_hint_a < name_hint_b;
465 });
466
467 // Add all functions except __tvm_main__
468 for (auto& kv : funcs) {
469 ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only take PrimFunc";
470 auto f = Downcast<PrimFunc>(kv.second);
471 cg.AddFunction(f);
472 }
473
474 // Add __tvm_main__
475 if (aot_executor_fn.defined()) {
476 emit_fwd_func_decl = true;
477 cg.AddFunction(aot_executor_fn, emit_fwd_func_decl);
478 }
479
480 // NOTE: it's possible that kRuntime attr is not attached when the mod was built with tvm.build().
481 // See issue #10373.
482 auto opt_runtime = mod->GetAttr<relay::Runtime>(tvm::attr::kRuntime);
483 relay::Runtime runtime;
484 if (opt_runtime.get() != nullptr) {
485 runtime = opt_runtime.value();
486 } else {
487 runtime = relay::Runtime::Create("cpp", {});
488 }
489 if (aot_executor_fn.defined() && runtime->name == relay::kTvmRuntimeCpp) {
490 cg.InitGlobalContext();
491 }
492
493 if (target->GetAttr<Bool>("system-lib").value_or(Bool(false))) {
494 ICHECK_EQ(target->GetAttr<String>("runtime").value_or(""), "c")
495 << "c target only supports generating C runtime SystemLibs";
496 }
497
498 std::string code = cg.Finish();
499 return CSourceModuleCreate(code, "c", cg.GetFunctionNames());
500}
501
// Register BuildCHost as the global packed function "target.build.c" (presumably looked up by
// name when building for the "c" target — confirm against the target build dispatcher).
TVM_REGISTER_GLOBAL("target.build.c").set_body_typed(BuildCHost);
503} // namespace codegen
504} // namespace tvm
505