1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_c.h |
 * \brief Common utilities to generate C style code.
23 | */ |
24 | #ifndef TVM_TARGET_SOURCE_CODEGEN_C_H_ |
25 | #define TVM_TARGET_SOURCE_CODEGEN_C_H_ |
26 | |
27 | #include <tvm/ir/op.h> |
28 | #include <tvm/target/codegen.h> |
29 | #include <tvm/tir/analysis.h> |
30 | #include <tvm/tir/builtin.h> |
31 | #include <tvm/tir/expr.h> |
32 | #include <tvm/tir/function.h> |
33 | #include <tvm/tir/op_attr_types.h> |
34 | #include <tvm/tir/stmt.h> |
35 | #include <tvm/tir/stmt_functor.h> |
36 | |
37 | #include <string> |
38 | #include <unordered_map> |
39 | #include <unordered_set> |
40 | #include <vector> |
41 | |
42 | #include "../../tir/transforms/ir_utils.h" |
43 | #include "codegen_source_base.h" |
44 | |
45 | namespace tvm { |
46 | namespace codegen { |
47 | |
48 | using namespace tir; |
49 | /*! |
50 | * \brief A base class to generate C code. |
51 | * |
 * CodeGenC has two modes: generate SSA formed C code or normal form.
53 | * |
 * **NOTE** CodeGenC does not aim at generating C code consumed by MSVC or GCC,
55 | * Rather, it's providing infrastructural abstraction for C variants like CUDA |
56 | * and OpenCL-C. You might find some odd variant features, e.g., type `int3` for |
57 | * a vector of 3 `int`s. For native C code generator, see `CodeGenLLVM`. |
58 | */ |
59 | class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>, |
60 | public StmtFunctor<void(const Stmt&)>, |
61 | public CodeGenSourceBase { |
62 | public: |
63 | /*! |
64 | * \brief Initialize the code generator. |
65 | * \param output_ssa Whether output SSA. |
66 | */ |
67 | void Init(bool output_ssa); |
68 | /*! |
69 | * \brief Add the function to the generated module. |
70 | * \param f The function to be compiled. |
71 | * \param whether to append return 0 in the end. |
72 | */ |
73 | void AddFunction(const PrimFunc& f); |
74 | /*! |
75 | * \brief Finalize the compilation and return the code. |
76 | * \return The code. |
77 | */ |
78 | virtual std::string Finish(); |
79 | /*! |
80 | * \brief Print the Stmt n to CodeGenC->stream |
81 | * \param n The statement to be printed. |
82 | */ |
83 | void PrintStmt(const Stmt& n) { VisitStmt(n); } |
84 | /*! |
85 | * \brief Print the expression n(or its ssa id if in ssa mode) into os |
86 | * \param n The expression to be printed. |
87 | * \param os The output stream |
88 | */ |
89 | void PrintExpr(const PrimExpr& n, std::ostream& os); |
90 | /*! |
91 | * \brief Same as PrintExpr, but simply returns result string |
92 | * \param n The expression to be printed. |
93 | */ |
94 | std::string PrintExpr(const PrimExpr& n) { |
95 | std::ostringstream os; |
96 | PrintExpr(n, os); |
97 | return os.str(); |
98 | } |
99 | // The following parts are overloadable print operations. |
100 | /*! |
101 | * \brief Print the function header before the argument list |
102 | * \param os The output stream |
103 | * |
104 | * Example: stream << "void"; |
105 | */ |
106 | virtual void PrintFuncPrefix(std::ostream& os); // NOLINT(*) |
107 | /*! |
108 | * \brief Print extra function attributes |
109 | * |
110 | * Example: __launch_bounds__(256) for CUDA functions |
111 | */ |
112 | virtual void (const PrimFunc& f); |
113 | /*! |
114 | * \brief Print the final return at the end the function. |
115 | */ |
116 | virtual void PrintFinalReturn(); // NOLINT(*) |
117 | /*! |
118 | * \brief Insert statement before function body. |
119 | * \param f The function to be compiled. |
120 | */ |
121 | virtual void PreFunctionBody(const PrimFunc& f) {} |
122 | /*! |
123 | * \brief Initialize codegen state for generating f. |
124 | * \param f The function to be compiled. |
125 | */ |
126 | virtual void InitFuncState(const PrimFunc& f); |
127 | // expression |
128 | void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) |
129 | void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) |
130 | void VisitExpr_(const BufferLoadNode* op, std::ostream& os) override; // NOLINT(*) |
131 | void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) |
132 | void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) |
133 | void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) |
134 | void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*) |
135 | void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) |
136 | void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*) |
137 | void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*) |
138 | void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*) |
139 | void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*) |
140 | void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*) |
141 | void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*) |
142 | void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*) |
143 | void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*) |
144 | void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*) |
145 | void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*) |
146 | void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*) |
147 | void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*) |
148 | void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*) |
149 | void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*) |
150 | void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) |
151 | void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) |
152 | void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*) |
153 | void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) |
154 | void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) |
155 | void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) |
156 | void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) |
157 | // statment |
158 | void VisitStmt_(const LetStmtNode* op) override; |
159 | void VisitStmt_(const StoreNode* op) override; |
160 | void VisitStmt_(const BufferStoreNode* op) override; |
161 | void VisitStmt_(const ForNode* op) override; |
162 | void VisitStmt_(const WhileNode* op) override; |
163 | void VisitStmt_(const IfThenElseNode* op) override; |
164 | void VisitStmt_(const AllocateNode* op) override; |
165 | void VisitStmt_(const AttrStmtNode* op) override; |
166 | void VisitStmt_(const AssertStmtNode* op) override; |
167 | void VisitStmt_(const EvaluateNode* op) override; |
168 | void VisitStmt_(const SeqStmtNode* op) override; |
169 | void VisitStmt_(const AllocateConstNode* op) override; |
170 | void VisitStmt_(const DeclBufferNode* op) override; |
171 | |
172 | /*! |
173 | * \brief Print expr representing the thread tag |
174 | * \param IterVar iv The thread index to be binded; |
175 | */ |
176 | virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*) |
177 | virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*) |
178 | virtual void PrintStorageSync(const CallNode* op); // NOLINT(*) |
179 | // Binary vector op. |
180 | virtual void PrintVecBinaryOp(const std::string& op, DataType op_type, PrimExpr lhs, PrimExpr rhs, |
181 | std::ostream& os); // NOLINT(*) |
182 | // print vector load |
183 | virtual std::string GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base); |
184 | // print vector store |
185 | virtual void PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base, |
186 | const std::string& value); // NOLINT(*) |
187 | // print load of single element |
188 | virtual void PrintVecElemLoad(const std::string& vec, DataType t, int i, |
189 | std::ostream& os); // NOLINT(*) |
190 | // print store of single element. |
191 | virtual void PrintVecElemStore(const std::string& vec, DataType t, int i, |
192 | const std::string& value); |
193 | // Get a cast type from to |
194 | virtual std::string CastFromTo(std::string value, DataType from, DataType target); |
195 | // Get load of single element with expression |
196 | virtual void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os); |
197 | // Print restrict keyword for a given Var if applicable |
198 | virtual void PrintRestrict(const Var& v, std::ostream& os); |
199 | |
200 | virtual void SetConstantsByteAlignment(Integer constants_byte_alignment) { |
201 | constants_byte_alignment_ = constants_byte_alignment; |
202 | } |
203 | |
204 | protected: |
205 | // Print reference to struct location |
206 | std::string GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind); |
207 | // Print reference to a buffer as type t in index. |
208 | virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index); |
209 | |
210 | /*! |
211 | * \brief Handle volatile loads. |
212 | * |
213 | * This is to workaround a bug in CUDA cuda_fp16.h. Volatile accesses |
214 | * to shared memory are required for reductions. However, __half class |
215 | * does not implement volatile member functions. CUDA codegen will cast |
216 | * away volatile qualifier from CUDA __half types. |
217 | */ |
218 | virtual void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op, |
219 | std::ostream& os) { |
220 | // By default, do nothing but print the loaded value. |
221 | os << value; |
222 | } |
223 | |
224 | /*! |
225 | * \brief Check if scope is part of type in the target language. |
226 | * |
227 | * **NOTE** In OpenCL, __local is part of type, so "__local int *" |
228 | * is legal. This is not the case for CUDA, where "__shared__" |
229 | * or "__constant__" is not part of type but a storage class (like |
230 | * C/C++ static). |
231 | */ |
232 | virtual bool IsScopePartOfType() const { return true; } |
233 | |
234 | /*! |
235 | * \brief Generate forward function declarations. |
236 | * \param global_symbol The symbolc of the target function. |
237 | * \param args The arguments to the function. |
238 | * \param os The output stream. |
239 | */ |
240 | virtual void GenerateForwardFunctionDeclarations(String global_symbol, |
241 | const Array<PrimExpr>& args) {} |
242 | /*! |
243 | * \brief Print external function call. |
244 | * \param ret_type The return type. |
245 | * \param global_symbol The symbolc of the target function. |
246 | * \param args The arguments to the function. |
247 | * \param skip_first_arg Whether to skip the first arguments. |
248 | * \param os The output stream. |
249 | */ |
250 | virtual void PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args, |
251 | bool skip_first_arg, std::ostream& os); // NOLINT(*) |
252 | /*! |
253 | * \brief If buffer is allocated as type t. |
254 | * \param buf_var The buffer variable. |
255 | * \param t The type to be checked. |
256 | */ |
257 | bool HandleTypeMatch(const VarNode* buf_var, DataType t) const; |
258 | /*! |
259 | * \brief Register the data type of buf_var |
260 | * \param buf_var The buffer variable. |
261 | * \param t The type to be checked. |
262 | */ |
263 | void RegisterHandleType(const VarNode* buf_var, DataType t); |
264 | // override |
265 | void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) final; |
266 | /*! \brief reserves common C keywords */ |
267 | void ReserveKeywordsAsUnique(); |
268 | |
269 | /*! \brief Check if buf_var is volatile or not. */ |
270 | bool IsVolatile(const VarNode* buf_var) const { return volatile_buf_.count(buf_var) != 0; } |
271 | |
272 | /*! \brief restrict keyword */ |
273 | std::string restrict_keyword_{"" }; |
274 | /*! \brief the storage scope of allocation */ |
275 | std::unordered_map<const VarNode*, std::string> alloc_storage_scope_; |
276 | /*! \brief the data type of allocated buffers */ |
277 | std::unordered_map<const VarNode*, DataType> handle_data_type_; |
278 | /*! \brief Record of ops that have pre-defined global symbol. */ |
279 | OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol" ); |
280 | // cache commonly used ops |
281 | const Op& builtin_call_extern_ = builtin::call_extern(); |
282 | const Op& builtin_call_pure_extern_ = builtin::call_pure_extern(); |
283 | Integer constants_byte_alignment_ = 16; |
284 | |
285 | private: |
286 | /*! \brief whether to print in SSA form */ |
287 | bool print_ssa_form_{false}; |
288 | /*! \brief set of volatile buf access */ |
289 | std::unordered_set<const VarNode*> volatile_buf_; |
290 | // deep comparison of PrimExpr |
291 | ExprDeepEqual deep_equal_; |
292 | // binding of let variables. Enables duplicate var defs that map to same value |
293 | std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_; |
294 | }; |
295 | |
296 | } // namespace codegen |
297 | } // namespace tvm |
298 | #endif // TVM_TARGET_SOURCE_CODEGEN_C_H_ |
299 | |