1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_c_host.h |
22 | * \brief Generate C host code. |
23 | */ |
24 | #ifndef TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_ |
25 | #define TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_ |
26 | |
27 | #include <string> |
28 | #include <unordered_map> |
29 | #include <unordered_set> |
30 | #include <utility> |
31 | #include <vector> |
32 | |
33 | #include "codegen_c.h" |
34 | #include "tvm/target/codegen.h" |
35 | #include "tvm/tir/expr.h" |
36 | |
37 | namespace tvm { |
38 | namespace codegen { |
39 | |
40 | class CodeGenCHost : public CodeGenC { |
41 | public: |
42 | CodeGenCHost(); |
43 | void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl, std::string target_str, |
44 | const std::unordered_set<std::string>& devices); |
45 | |
46 | void InitGlobalContext(); |
47 | void AddFunction(const PrimFunc& f, bool emit_fwd_func_decl = false); |
48 | std::string Finish() final; |
49 | /*! |
50 | * \brief Add functions from the (unordered) range to the current module in a deterministic |
51 | * order. This helps with debugging. |
52 | * |
53 | * \param functions A vector of unordered range of current module. |
54 | */ |
55 | void AddFunctionsOrdered(std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> functions); |
56 | void DefineModuleName(); |
57 | |
58 | void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) |
59 | void PrintFuncPrefix(std::ostream& os) final; // NOLINT(*) |
60 | void PrintFinalReturn() final; // NOLINT(*) |
61 | |
62 | // overload visitor functions |
63 | void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) |
64 | void VisitExpr_(const CallNode* op, std::ostream& os); // NOLINT(*) |
65 | // overload min and max to use the ternary operator, so we don't rely on the |
66 | // standard library implementations |
67 | void VisitExpr_(const MinNode* op, std::ostream& os) final; // NOLINT(*) |
68 | void VisitExpr_(const MaxNode* op, std::ostream& os) final; // NOLINT(*) |
69 | |
70 | void VisitStmt_(const AssertStmtNode* op) final; // NOLINT(*) |
71 | |
72 | virtual void GenerateForwardFunctionDeclarations(String global_symbol, |
73 | const Array<PrimExpr>& args); // NOLINT(*) |
74 | Array<String> GetFunctionNames() { return function_names_; } |
75 | |
76 | private: |
77 | /* \brief Internal structure to store information about function calls */ |
78 | struct FunctionInfo { |
79 | /* \brief function name */ |
80 | std::string func_name; |
81 | /* number of arguments required by the function */ |
82 | int64_t num_args; |
83 | /* \brief name of resource_handle to pass */ |
84 | std::string resource_handle_name; |
85 | }; |
86 | std::string module_name_; |
87 | /* \brief mapping global packed func to the unique name */ |
88 | std::unordered_map<std::string, std::string> declared_globals_; |
89 | /* \brief names of the functions declared in this module */ |
90 | Array<String> function_names_; |
91 | /*! \brief whether to emit asserts in the resulting C code */ |
92 | bool emit_asserts_; |
93 | /*! \brief whether to emit forwared function declarations in the resulting C code */ |
94 | bool emit_fwd_func_decl_; |
95 | |
96 | FunctionInfo GetFunctionInfo(const CallNode* op, bool has_resource_handle); |
97 | std::string GetPackedName(const CallNode* op); |
98 | void PrintGetFuncFromBackend(const std::string& func_name, const std::string& packed_func_name); |
99 | void PrintFuncCall(const std::string& packed_func_name, int num_args); |
100 | void PrintFuncCallC(const std::string& packed_func_name, int num_args, |
101 | const std::string& resource_handle_name); |
102 | |
103 | /*! |
104 | * \brief Print ternary conditional operator implementing binary `op` |
105 | * Forces the operands to be in SSA form. |
106 | * \param op binary operator being expressed |
107 | * \param compare string representation of comparison operator |
108 | * \param os stream reference to print into |
109 | */ |
110 | template <typename T> |
111 | inline void PrintTernaryCondExpr(const T* op, const char* compare, |
112 | std::ostream& os); // NOLINT(*) |
113 | }; |
114 | |
115 | } // namespace codegen |
116 | } // namespace tvm |
117 | |
118 | #endif // TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_ |
119 | |