1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#include <tvm/relay/transform.h>
21#include <tvm/relay/type.h>
22#include <tvm/runtime/module.h>
23#include <tvm/runtime/ndarray.h>
24#include <tvm/runtime/object.h>
25
26#include <sstream>
27#include <string>
28
29#include "../../../transforms/compiler_function_utils.h"
30#include "../../utils.h"
31#include "codegen_c.h"
32
33namespace tvm {
34namespace relay {
35namespace contrib {
36
37/*! \brief Return the "ccompiler" Target instance to use to guide compilation. */
38Target GetCCompilerTarget() {
39 Target target = Target::Current(/*allow_not_defined=*/true);
40 if (!target.defined() || target->kind->name != "ccompiler") {
41 // Use the default compilation options if no specific "ccompiler" target was given
42 // in the overall targets list. In that case target_hooks.cc will invoke the custom pass
43 // without pushing any target instance onto the implicit target stack.
44 target = Target("ccompiler");
45 }
46 return target;
47}
48
49/*!
50 * \brief Emits C/C++ code for a single function.
51 *
52 * For testing and demonstration only, only a few binary operators are supported.
53 */
54class CodegenC : public backend::MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
55 public:
56 CodegenC(std::unordered_map<std::string, runtime::NDArray>* const_name_to_constant,
57 Array<String>* const_names, bool* needs_extra_headers, std::string ext_func_id)
58 : const_name_to_constant_(const_name_to_constant),
59 const_names_(const_names),
60 needs_extra_headers_(needs_extra_headers),
61 ext_func_id_(std::move(ext_func_id)) {}
62
63 /*!
64 * \brief Emit the source code that invokes C compiler compatible wrappers.
65 *
66 * \return The emitted code.
67 */
68 std::string JIT(const std::vector<Output>& out) override {
69 // Write function macros
70 for (auto decl : func_decl_) {
71 code_stream_ << decl << "\n";
72 }
73 return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_name_, out);
74 }
75
76 private:
77 std::vector<Output> VisitExprDefault_(const Object* op) override {
78 LOG(FATAL) << "C codegen doesn't support: " << op->GetTypeKey();
79 }
80
81 std::vector<Output> VisitExpr_(const VarNode* node) override {
82 ext_func_args_.push_back(GetRef<Var>(node));
83 Output output;
84 output.name = node->name_hint();
85 return {output};
86 }
87
88 std::vector<Output> VisitExpr_(const TupleNode* node) override {
89 std::vector<Output> outs;
90 for (auto field : node->fields) {
91 auto res = VisitExpr(field);
92 ICHECK_EQ(res.size(), 1U) << "Do not support tuple nest";
93 outs.push_back(res[0]);
94 }
95 return outs;
96 }
97
98 std::vector<Output> VisitExpr_(const TupleGetItemNode* op) override {
99 auto res = VisitExpr(op->tuple);
100 ICHECK_GT(res.size(), static_cast<size_t>(op->index));
101
102 // Only keep the item we want for the child node.
103 // FIXME(@comaniac): The other items should still be requried for the primary outputs.
104 return {res[op->index]};
105 }
106
107 std::vector<Output> VisitExpr_(const ConstantNode* cn) override {
108 // Remember we'll need some extra headers to support the runtime constants array.
109 *needs_extra_headers_ = true;
110
111 std::ostringstream decl_stream;
112 std::ostringstream buf_stream;
113
114 Output output;
115 // Get const: static_cast<float*>(gcc_0_consts[0]->data)
116 size_t const_id = const_name_to_constant_->size();
117 output.name = CreateDataReference(ext_func_id_, const_id);
118 const auto* type_node = cn->checked_type().as<TensorTypeNode>();
119 ICHECK(type_node);
120 const auto& dtype = GetDtypeString(type_node);
121
122 // Generate the global variable for needed ndarrays
123 if (const_array_name_.empty()) {
124 *needs_extra_headers_ = true;
125 const_array_name_ = CreateNDArrayPool(ext_func_id_);
126 std::string checker = CreateInitChecker(ext_func_id_);
127 ext_func_body_.insert(ext_func_body_.begin(), checker);
128 }
129
130 ICHECK(dtype == "float" || dtype == "int") << "Only float and int are supported for now.";
131 output.dtype = dtype;
132
133 std::string const_var_name = CreateConstVar(ext_func_id_, const_id);
134 const_name_to_constant_->emplace(const_var_name, cn->data);
135 const_names_->push_back(const_var_name);
136
137 return {output};
138 }
139
140 std::vector<Output> VisitExpr_(const CallNode* call) override {
141 std::ostringstream macro_stream;
142 std::ostringstream decl_stream;
143 std::ostringstream buf_stream;
144
145 std::string func_name = ext_func_id_ + "_" + std::to_string(func_idx++);
146
147 // Make function declaration
148 macro_stream << "CSOURCE_BINARY_OP_" << call->args.size() << "D(" << func_name << ", ";
149
150 if (backend::IsOp(call, "add")) {
151 macro_stream << "+";
152 } else if (backend::IsOp(call, "subtract")) {
153 macro_stream << "-";
154 } else if (backend::IsOp(call, "multiply")) {
155 macro_stream << "*";
156 } else {
157 LOG(FATAL) << "Unrecognized op";
158 }
159
160 auto in_shape = backend::GetShape(call->args[0]->checked_type());
161 for (size_t i = 0; i < in_shape.size(); ++i) {
162 macro_stream << ", " << in_shape[i];
163 }
164
165 const auto* type_node = call->checked_type().as<TensorTypeNode>();
166 ICHECK(type_node);
167 const auto& dtype = GetDtypeString(type_node);
168 macro_stream << ", " << dtype;
169
170 macro_stream << ");";
171 func_decl_.push_back(macro_stream.str());
172
173 // Make function call when visiting arguments
174 bool first = true;
175 decl_stream << func_name << "(";
176 for (size_t i = 0; i < call->args.size(); ++i) {
177 auto res = VisitExpr(call->args[i]);
178 for (auto out : res) {
179 if (!first) {
180 decl_stream << ", ";
181 }
182 first = false;
183 decl_stream << out.name;
184 }
185 }
186
187 std::string out = "buf_" + std::to_string(buf_idx_++);
188 auto out_shape = backend::GetShape(call->checked_type());
189 int out_size = 1;
190 for (size_t i = 0; i < out_shape.size(); ++i) {
191 out_size *= out_shape[i];
192 }
193 buf_stream << dtype << "* " << out << " = (" << dtype << "*)malloc(4 * " << out_size << ");";
194 buf_decl_.push_back(buf_stream.str());
195
196 decl_stream << ", " << out << ");";
197 ext_func_body_.push_back(decl_stream.str());
198
199 // Update output buffer
200 // Note C codegen only handles TensorType. Therefore, we don't flatten
201 // tuples and only return a single vaule.
202 Output output;
203 output.name = out;
204 output.dtype = dtype;
205 output.need_copy = true;
206 output.size = out_size;
207 return {output};
208 }
209
210 /*!
211 * \brief The accumulated constant name to constant mapping. Shared between all generated
212 * functions.
213 */
214 std::unordered_map<std::string, runtime::NDArray>* const_name_to_constant_;
215 /*! \brief The accumulated constant names, in the order they were generated. */
216 Array<String>* const_names_;
217 /*!
218 * \brief Set to true if the ndarray and packed function headers are required to declare and
219 * manage the constants array.
220 */
221 bool* needs_extra_headers_;
222 /*! \brief Name of the global function currently being compiled. */
223 std::string ext_func_id_;
224
225 /*! \brief The index of the next available wrapped C function. */
226 int func_idx = 0;
227 /*! \brief The index of the next available allocated buffers. */
228 int buf_idx_ = 0;
229 /*! \brief The arguments of a C compiler compatible function. */
230 Array<Var> ext_func_args_;
231 /*! \brief The statements of a C compiler compatible function. */
232 std::vector<std::string> ext_func_body_;
233 /*! \brief The array declared to store the constant values. */
234 std::string const_array_name_;
235 /*! \brief The declaration statements of a C compiler compatible function. */
236 std::vector<std::string> func_decl_;
237 /*! \brief The declaration statements of buffers. */
238 std::vector<std::string> buf_decl_;
239};
240
241/*! \brief Emits C/C++ code for a module. */
242class CodegenCModule {
243 public:
244 CodegenCModule(Target target, IRModule mod) : target_(std::move(target)), mod_(std::move(mod)) {}
245
246 runtime::Module CreateCSourceModule() {
247 for (const auto& kv : mod_->functions) {
248 if (const auto* function_node = GetCCompilerFunctionNode(kv.second)) {
249 GenCFunc(GetRef<Function>(function_node));
250 }
251 }
252 return Finalize();
253 }
254
255 /*! \brief Returns the accumulated constant name to constant mapping. */
256 const std::unordered_map<std::string, runtime::NDArray>& const_name_to_constant() const {
257 return const_name_to_constant_;
258 }
259
260 private:
261 /*! \brief Emits the standard C/C++ header into \p os. */
262 void EmitPreamble(std::ostringstream& os) {
263 // Custom header, if any.
264 Optional<String> header = target_->GetAttr<String>("header");
265 if (header.defined() && !header.value().empty()) {
266 os << header.value().c_str() << "\n";
267 }
268
269 // Standard includes.
270 os << "#include <stdio.h>\n";
271 os << "#include <stdlib.h>\n";
272 os << "#include <string.h>\n";
273 os << "#include <tvm/runtime/c_runtime_api.h>\n";
274 os << "#include <tvm/runtime/c_backend_api.h>\n";
275
276 if (needs_extra_headers_) {
277 // This segment would be generated in C++ because of the usage
278 // of tvm::runtime::Array. This is not ideal, but this to demonstrate
279 // constant copying process used packed imports in other external
280 // codegen. Moreover, in microTVM we dont expect this part to be generated.
281 os << "#ifdef __cplusplus\n";
282 os << "#include <tvm/runtime/ndarray.h>\n";
283 os << "#include <tvm/runtime/packed_func.h>\n";
284 os << "#endif\n";
285 }
286
287 // Define some macros to help operator implementations.
288 const char* operator_macro = R"op_macro(
289 #define CSOURCE_BINARY_OP_1D(p_ID_, p_OP_, p_DIM1_, p_DTYPE) \
290 void p_ID_(p_DTYPE* a, p_DTYPE* b, p_DTYPE* out) { \
291 for (int64_t i = 0; i < p_DIM1_; ++i) { \
292 out[i] = a[i] p_OP_ b[i]; \
293 } \
294 }
295
296 #define CSOURCE_BINARY_OP_2D(p_ID_, p_OP_, p_DIM1_, p_DIM2_, p_DTYPE) \
297 void p_ID_(p_DTYPE* a, p_DTYPE* b, p_DTYPE* out) { \
298 for (int64_t i = 0; i < p_DIM1_; ++i) { \
299 for (int64_t j = 0; j < p_DIM2_; ++j) { \
300 int64_t k = i * p_DIM2_ + j; \
301 out[k] = a[k] p_OP_ b[k]; \
302 } \
303 } \
304 }
305 )op_macro";
306
307 os << operator_macro << "\n\n";
308 }
309
310 void GenCFunc(const Function& function) {
311 ICHECK(function.defined()) << "Input error: expect a Relay function.";
312 std::string ext_func_id = backend::GetExtSymbol(function);
313 CodegenC builder(&const_name_to_constant_, &const_names_, &needs_extra_headers_, ext_func_id);
314 std::vector<Output> out = builder.VisitExpr(function->body);
315 code_stream_ << builder.JIT(out);
316 func_names_.push_back(ext_func_id);
317 }
318
319 /*! \brief Returns function if it is tagged with "Compiler=ccompiler". */
320 static const FunctionNode* GetCCompilerFunctionNode(const Expr& expr) {
321 if (const auto* function_node = expr.as<FunctionNode>()) {
322 Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler);
323 if (opt_compiler.defined() && opt_compiler.value() == "ccompiler") {
324 return function_node;
325 }
326 }
327 return nullptr;
328 }
329
330 runtime::Module Finalize() {
331 std::ostringstream os;
332 EmitPreamble(os);
333 os << code_stream_.str();
334 std::string code = os.str();
335
336 VLOG(1) << "CodegenCModule generated:" << std::endl << code;
337
338 // Create a CSource module
339 const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate");
340 ICHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module";
341 return (*pf)(code, "c", func_names_, const_names_);
342 }
343
344 /*! \brief "ccompiler" Target with compilation options to use. */
345 Target target_;
346 /*! \brief Module we are compiling. */
347 IRModule mod_;
348
349 /*! \brief True if we need to include the ndarray and packed function headers. */
350 bool needs_extra_headers_ = false;
351 /*! \brief The accumulated constant name to constant mapping. */
352 std::unordered_map<std::string, runtime::NDArray> const_name_to_constant_;
353 /*! \brief The accumulated constant names, in the order they were generated. */
354 Array<String> const_names_;
355 /*! \brief The accumulated function names. */
356 Array<String> func_names_;
357 /*!
358 * \brief The accumulated code stream containing all function definitions.
359 * (Does not include the preamble.)
360 */
361 std::ostringstream code_stream_;
362};
363
364/*! \brief The actual translation pass. */
365tvm::transform::Pass CCompilerImpl() {
366 auto pass_func = [=](IRModule mod, const tvm::transform::PassContext& pass_ctx) {
367 VLOG(1) << "CCompilerImpl input:" << std::endl << PrettyPrint(mod);
368 Target target = GetCCompilerTarget();
369
370 // Emit the C/C++ code and package it as a CSourceModule.
371 CodegenCModule codegen(target, mod);
372 runtime::Module runtime_mod = codegen.CreateCSourceModule();
373
374 // Capture the new runtime module.
375 Array<runtime::Module> external_mods =
376 mod->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({});
377 external_mods.push_back(runtime_mod);
378
379 // Capture the new constants.
380 Map<String, runtime::NDArray> const_name_to_constant =
381 mod->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToConstant).value_or({});
382 for (const auto& kv : codegen.const_name_to_constant()) {
383 ICHECK_EQ(const_name_to_constant.count(kv.first), 0);
384 const_name_to_constant.Set(kv.first, kv.second);
385 }
386
387 return WithAttrs(mod, {{tvm::attr::kExternalMods, external_mods},
388 {tvm::attr::kConstNameToConstant, const_name_to_constant}});
389 };
390 return tvm::transform::CreateModulePass(pass_func, 0, "CCompilerImpl", {});
391}
392
393tvm::transform::Pass CCompilerPass() {
394 return transform::Sequential(
395 {transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"), CCompilerImpl(),
396 transform::MarkCompilerFunctionsAsExtern("ccompiler")});
397}
398
399} // namespace contrib
400} // namespace relay
401} // namespace tvm
402