1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #include <tvm/relay/transform.h> |
21 | #include <tvm/relay/type.h> |
22 | #include <tvm/runtime/module.h> |
23 | #include <tvm/runtime/ndarray.h> |
24 | #include <tvm/runtime/object.h> |
25 | |
26 | #include <sstream> |
27 | #include <string> |
28 | |
29 | #include "../../../transforms/compiler_function_utils.h" |
30 | #include "../../utils.h" |
31 | #include "codegen_c.h" |
32 | |
33 | namespace tvm { |
34 | namespace relay { |
35 | namespace contrib { |
36 | |
37 | /*! \brief Return the "ccompiler" Target instance to use to guide compilation. */ |
38 | Target GetCCompilerTarget() { |
39 | Target target = Target::Current(/*allow_not_defined=*/true); |
40 | if (!target.defined() || target->kind->name != "ccompiler" ) { |
41 | // Use the default compilation options if no specific "ccompiler" target was given |
42 | // in the overall targets list. In that case target_hooks.cc will invoke the custom pass |
43 | // without pushing any target instance onto the implicit target stack. |
44 | target = Target("ccompiler" ); |
45 | } |
46 | return target; |
47 | } |
48 | |
49 | /*! |
50 | * \brief Emits C/C++ code for a single function. |
51 | * |
52 | * For testing and demonstration only, only a few binary operators are supported. |
53 | */ |
54 | class CodegenC : public backend::MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase { |
55 | public: |
56 | CodegenC(std::unordered_map<std::string, runtime::NDArray>* const_name_to_constant, |
57 | Array<String>* const_names, bool* , std::string ext_func_id) |
58 | : const_name_to_constant_(const_name_to_constant), |
59 | const_names_(const_names), |
60 | needs_extra_headers_(needs_extra_headers), |
61 | ext_func_id_(std::move(ext_func_id)) {} |
62 | |
63 | /*! |
64 | * \brief Emit the source code that invokes C compiler compatible wrappers. |
65 | * |
66 | * \return The emitted code. |
67 | */ |
68 | std::string JIT(const std::vector<Output>& out) override { |
69 | // Write function macros |
70 | for (auto decl : func_decl_) { |
71 | code_stream_ << decl << "\n" ; |
72 | } |
73 | return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_name_, out); |
74 | } |
75 | |
76 | private: |
77 | std::vector<Output> VisitExprDefault_(const Object* op) override { |
78 | LOG(FATAL) << "C codegen doesn't support: " << op->GetTypeKey(); |
79 | } |
80 | |
81 | std::vector<Output> VisitExpr_(const VarNode* node) override { |
82 | ext_func_args_.push_back(GetRef<Var>(node)); |
83 | Output output; |
84 | output.name = node->name_hint(); |
85 | return {output}; |
86 | } |
87 | |
88 | std::vector<Output> VisitExpr_(const TupleNode* node) override { |
89 | std::vector<Output> outs; |
90 | for (auto field : node->fields) { |
91 | auto res = VisitExpr(field); |
92 | ICHECK_EQ(res.size(), 1U) << "Do not support tuple nest" ; |
93 | outs.push_back(res[0]); |
94 | } |
95 | return outs; |
96 | } |
97 | |
98 | std::vector<Output> VisitExpr_(const TupleGetItemNode* op) override { |
99 | auto res = VisitExpr(op->tuple); |
100 | ICHECK_GT(res.size(), static_cast<size_t>(op->index)); |
101 | |
102 | // Only keep the item we want for the child node. |
103 | // FIXME(@comaniac): The other items should still be requried for the primary outputs. |
104 | return {res[op->index]}; |
105 | } |
106 | |
107 | std::vector<Output> VisitExpr_(const ConstantNode* cn) override { |
108 | // Remember we'll need some extra headers to support the runtime constants array. |
109 | *needs_extra_headers_ = true; |
110 | |
111 | std::ostringstream decl_stream; |
112 | std::ostringstream buf_stream; |
113 | |
114 | Output output; |
115 | // Get const: static_cast<float*>(gcc_0_consts[0]->data) |
116 | size_t const_id = const_name_to_constant_->size(); |
117 | output.name = CreateDataReference(ext_func_id_, const_id); |
118 | const auto* type_node = cn->checked_type().as<TensorTypeNode>(); |
119 | ICHECK(type_node); |
120 | const auto& dtype = GetDtypeString(type_node); |
121 | |
122 | // Generate the global variable for needed ndarrays |
123 | if (const_array_name_.empty()) { |
124 | *needs_extra_headers_ = true; |
125 | const_array_name_ = CreateNDArrayPool(ext_func_id_); |
126 | std::string checker = CreateInitChecker(ext_func_id_); |
127 | ext_func_body_.insert(ext_func_body_.begin(), checker); |
128 | } |
129 | |
130 | ICHECK(dtype == "float" || dtype == "int" ) << "Only float and int are supported for now." ; |
131 | output.dtype = dtype; |
132 | |
133 | std::string const_var_name = CreateConstVar(ext_func_id_, const_id); |
134 | const_name_to_constant_->emplace(const_var_name, cn->data); |
135 | const_names_->push_back(const_var_name); |
136 | |
137 | return {output}; |
138 | } |
139 | |
140 | std::vector<Output> VisitExpr_(const CallNode* call) override { |
141 | std::ostringstream macro_stream; |
142 | std::ostringstream decl_stream; |
143 | std::ostringstream buf_stream; |
144 | |
145 | std::string func_name = ext_func_id_ + "_" + std::to_string(func_idx++); |
146 | |
147 | // Make function declaration |
148 | macro_stream << "CSOURCE_BINARY_OP_" << call->args.size() << "D(" << func_name << ", " ; |
149 | |
150 | if (backend::IsOp(call, "add" )) { |
151 | macro_stream << "+" ; |
152 | } else if (backend::IsOp(call, "subtract" )) { |
153 | macro_stream << "-" ; |
154 | } else if (backend::IsOp(call, "multiply" )) { |
155 | macro_stream << "*" ; |
156 | } else { |
157 | LOG(FATAL) << "Unrecognized op" ; |
158 | } |
159 | |
160 | auto in_shape = backend::GetShape(call->args[0]->checked_type()); |
161 | for (size_t i = 0; i < in_shape.size(); ++i) { |
162 | macro_stream << ", " << in_shape[i]; |
163 | } |
164 | |
165 | const auto* type_node = call->checked_type().as<TensorTypeNode>(); |
166 | ICHECK(type_node); |
167 | const auto& dtype = GetDtypeString(type_node); |
168 | macro_stream << ", " << dtype; |
169 | |
170 | macro_stream << ");" ; |
171 | func_decl_.push_back(macro_stream.str()); |
172 | |
173 | // Make function call when visiting arguments |
174 | bool first = true; |
175 | decl_stream << func_name << "(" ; |
176 | for (size_t i = 0; i < call->args.size(); ++i) { |
177 | auto res = VisitExpr(call->args[i]); |
178 | for (auto out : res) { |
179 | if (!first) { |
180 | decl_stream << ", " ; |
181 | } |
182 | first = false; |
183 | decl_stream << out.name; |
184 | } |
185 | } |
186 | |
187 | std::string out = "buf_" + std::to_string(buf_idx_++); |
188 | auto out_shape = backend::GetShape(call->checked_type()); |
189 | int out_size = 1; |
190 | for (size_t i = 0; i < out_shape.size(); ++i) { |
191 | out_size *= out_shape[i]; |
192 | } |
193 | buf_stream << dtype << "* " << out << " = (" << dtype << "*)malloc(4 * " << out_size << ");" ; |
194 | buf_decl_.push_back(buf_stream.str()); |
195 | |
196 | decl_stream << ", " << out << ");" ; |
197 | ext_func_body_.push_back(decl_stream.str()); |
198 | |
199 | // Update output buffer |
200 | // Note C codegen only handles TensorType. Therefore, we don't flatten |
201 | // tuples and only return a single vaule. |
202 | Output output; |
203 | output.name = out; |
204 | output.dtype = dtype; |
205 | output.need_copy = true; |
206 | output.size = out_size; |
207 | return {output}; |
208 | } |
209 | |
210 | /*! |
211 | * \brief The accumulated constant name to constant mapping. Shared between all generated |
212 | * functions. |
213 | */ |
214 | std::unordered_map<std::string, runtime::NDArray>* const_name_to_constant_; |
215 | /*! \brief The accumulated constant names, in the order they were generated. */ |
216 | Array<String>* const_names_; |
217 | /*! |
218 | * \brief Set to true if the ndarray and packed function headers are required to declare and |
219 | * manage the constants array. |
220 | */ |
221 | bool* ; |
222 | /*! \brief Name of the global function currently being compiled. */ |
223 | std::string ext_func_id_; |
224 | |
225 | /*! \brief The index of the next available wrapped C function. */ |
226 | int func_idx = 0; |
227 | /*! \brief The index of the next available allocated buffers. */ |
228 | int buf_idx_ = 0; |
229 | /*! \brief The arguments of a C compiler compatible function. */ |
230 | Array<Var> ext_func_args_; |
231 | /*! \brief The statements of a C compiler compatible function. */ |
232 | std::vector<std::string> ext_func_body_; |
233 | /*! \brief The array declared to store the constant values. */ |
234 | std::string const_array_name_; |
235 | /*! \brief The declaration statements of a C compiler compatible function. */ |
236 | std::vector<std::string> func_decl_; |
237 | /*! \brief The declaration statements of buffers. */ |
238 | std::vector<std::string> buf_decl_; |
239 | }; |
240 | |
241 | /*! \brief Emits C/C++ code for a module. */ |
242 | class CodegenCModule { |
243 | public: |
244 | CodegenCModule(Target target, IRModule mod) : target_(std::move(target)), mod_(std::move(mod)) {} |
245 | |
246 | runtime::Module CreateCSourceModule() { |
247 | for (const auto& kv : mod_->functions) { |
248 | if (const auto* function_node = GetCCompilerFunctionNode(kv.second)) { |
249 | GenCFunc(GetRef<Function>(function_node)); |
250 | } |
251 | } |
252 | return Finalize(); |
253 | } |
254 | |
255 | /*! \brief Returns the accumulated constant name to constant mapping. */ |
256 | const std::unordered_map<std::string, runtime::NDArray>& const_name_to_constant() const { |
257 | return const_name_to_constant_; |
258 | } |
259 | |
260 | private: |
261 | /*! \brief Emits the standard C/C++ header into \p os. */ |
262 | void EmitPreamble(std::ostringstream& os) { |
263 | // Custom header, if any. |
264 | Optional<String> = target_->GetAttr<String>("header" ); |
265 | if (header.defined() && !header.value().empty()) { |
266 | os << header.value().c_str() << "\n" ; |
267 | } |
268 | |
269 | // Standard includes. |
270 | os << "#include <stdio.h>\n" ; |
271 | os << "#include <stdlib.h>\n" ; |
272 | os << "#include <string.h>\n" ; |
273 | os << "#include <tvm/runtime/c_runtime_api.h>\n" ; |
274 | os << "#include <tvm/runtime/c_backend_api.h>\n" ; |
275 | |
276 | if (needs_extra_headers_) { |
277 | // This segment would be generated in C++ because of the usage |
278 | // of tvm::runtime::Array. This is not ideal, but this to demonstrate |
279 | // constant copying process used packed imports in other external |
280 | // codegen. Moreover, in microTVM we dont expect this part to be generated. |
281 | os << "#ifdef __cplusplus\n" ; |
282 | os << "#include <tvm/runtime/ndarray.h>\n" ; |
283 | os << "#include <tvm/runtime/packed_func.h>\n" ; |
284 | os << "#endif\n" ; |
285 | } |
286 | |
287 | // Define some macros to help operator implementations. |
288 | const char* operator_macro = R"op_macro( |
289 | #define CSOURCE_BINARY_OP_1D(p_ID_, p_OP_, p_DIM1_, p_DTYPE) \ |
290 | void p_ID_(p_DTYPE* a, p_DTYPE* b, p_DTYPE* out) { \ |
291 | for (int64_t i = 0; i < p_DIM1_; ++i) { \ |
292 | out[i] = a[i] p_OP_ b[i]; \ |
293 | } \ |
294 | } |
295 | |
296 | #define CSOURCE_BINARY_OP_2D(p_ID_, p_OP_, p_DIM1_, p_DIM2_, p_DTYPE) \ |
297 | void p_ID_(p_DTYPE* a, p_DTYPE* b, p_DTYPE* out) { \ |
298 | for (int64_t i = 0; i < p_DIM1_; ++i) { \ |
299 | for (int64_t j = 0; j < p_DIM2_; ++j) { \ |
300 | int64_t k = i * p_DIM2_ + j; \ |
301 | out[k] = a[k] p_OP_ b[k]; \ |
302 | } \ |
303 | } \ |
304 | } |
305 | )op_macro" ; |
306 | |
307 | os << operator_macro << "\n\n" ; |
308 | } |
309 | |
310 | void GenCFunc(const Function& function) { |
311 | ICHECK(function.defined()) << "Input error: expect a Relay function." ; |
312 | std::string ext_func_id = backend::GetExtSymbol(function); |
313 | CodegenC builder(&const_name_to_constant_, &const_names_, &needs_extra_headers_, ext_func_id); |
314 | std::vector<Output> out = builder.VisitExpr(function->body); |
315 | code_stream_ << builder.JIT(out); |
316 | func_names_.push_back(ext_func_id); |
317 | } |
318 | |
319 | /*! \brief Returns function if it is tagged with "Compiler=ccompiler". */ |
320 | static const FunctionNode* GetCCompilerFunctionNode(const Expr& expr) { |
321 | if (const auto* function_node = expr.as<FunctionNode>()) { |
322 | Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler); |
323 | if (opt_compiler.defined() && opt_compiler.value() == "ccompiler" ) { |
324 | return function_node; |
325 | } |
326 | } |
327 | return nullptr; |
328 | } |
329 | |
330 | runtime::Module Finalize() { |
331 | std::ostringstream os; |
332 | EmitPreamble(os); |
333 | os << code_stream_.str(); |
334 | std::string code = os.str(); |
335 | |
336 | VLOG(1) << "CodegenCModule generated:" << std::endl << code; |
337 | |
338 | // Create a CSource module |
339 | const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate" ); |
340 | ICHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module" ; |
341 | return (*pf)(code, "c" , func_names_, const_names_); |
342 | } |
343 | |
344 | /*! \brief "ccompiler" Target with compilation options to use. */ |
345 | Target target_; |
346 | /*! \brief Module we are compiling. */ |
347 | IRModule mod_; |
348 | |
349 | /*! \brief True if we need to include the ndarray and packed function headers. */ |
350 | bool = false; |
351 | /*! \brief The accumulated constant name to constant mapping. */ |
352 | std::unordered_map<std::string, runtime::NDArray> const_name_to_constant_; |
353 | /*! \brief The accumulated constant names, in the order they were generated. */ |
354 | Array<String> const_names_; |
355 | /*! \brief The accumulated function names. */ |
356 | Array<String> func_names_; |
357 | /*! |
358 | * \brief The accumulated code stream containing all function definitions. |
359 | * (Does not include the preamble.) |
360 | */ |
361 | std::ostringstream code_stream_; |
362 | }; |
363 | |
364 | /*! \brief The actual translation pass. */ |
365 | tvm::transform::Pass CCompilerImpl() { |
366 | auto pass_func = [=](IRModule mod, const tvm::transform::PassContext& pass_ctx) { |
367 | VLOG(1) << "CCompilerImpl input:" << std::endl << PrettyPrint(mod); |
368 | Target target = GetCCompilerTarget(); |
369 | |
370 | // Emit the C/C++ code and package it as a CSourceModule. |
371 | CodegenCModule codegen(target, mod); |
372 | runtime::Module runtime_mod = codegen.CreateCSourceModule(); |
373 | |
374 | // Capture the new runtime module. |
375 | Array<runtime::Module> external_mods = |
376 | mod->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({}); |
377 | external_mods.push_back(runtime_mod); |
378 | |
379 | // Capture the new constants. |
380 | Map<String, runtime::NDArray> const_name_to_constant = |
381 | mod->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToConstant).value_or({}); |
382 | for (const auto& kv : codegen.const_name_to_constant()) { |
383 | ICHECK_EQ(const_name_to_constant.count(kv.first), 0); |
384 | const_name_to_constant.Set(kv.first, kv.second); |
385 | } |
386 | |
387 | return WithAttrs(mod, {{tvm::attr::kExternalMods, external_mods}, |
388 | {tvm::attr::kConstNameToConstant, const_name_to_constant}}); |
389 | }; |
390 | return tvm::transform::CreateModulePass(pass_func, 0, "CCompilerImpl" , {}); |
391 | } |
392 | |
393 | tvm::transform::Pass CCompilerPass() { |
394 | return transform::Sequential( |
395 | {transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler" ), CCompilerImpl(), |
396 | transform::MarkCompilerFunctionsAsExtern("ccompiler" )}); |
397 | } |
398 | |
399 | } // namespace contrib |
400 | } // namespace relay |
401 | } // namespace tvm |
402 | |